#917 - Almost there. All output files being written for SB3/RLLIB PPO & A2C. Just need to bring in the hardcoded agents then update the tests and docs.

This commit is contained in:
Chris McCarthy
2023-06-19 21:53:25 +01:00
parent 23bafde457
commit 03ae4884e0
14 changed files with 321 additions and 328 deletions

View File

@@ -1 +1 @@
2.0.0rc1
2.0.0b1

View File

@@ -5,19 +5,18 @@ from pathlib import Path
from typing import Optional, Final, Dict, Union, List
from uuid import uuid4
from primaite import getLogger
from primaite import getLogger, SESSIONS_DIR
from primaite.common.enums import OutputVerboseLevel
from primaite.config import lay_down_config
from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
@@ -26,7 +25,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path("./") / date_dir / session_path
session_path = SESSIONS_DIR / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
@@ -57,16 +56,16 @@ class AgentSessionABC(ABC):
self._env: Primaite
self._agent = None
self._transaction_list: List[Dict] = []
self._can_learn: bool = False
self._can_evaluate: bool = False
self._uuid = str(uuid4())
self.session_timestamp: datetime = datetime.now()
"The session timestamp"
self.session_path = _get_temp_session_path(self.session_timestamp)
self.session_path = _get_session_path(self.session_timestamp)
"The Session path"
self.checkpoints_path = self.session_path / "checkpoints"
self.checkpoints_path.mkdir(parents=True, exist_ok=True)
"The Session checkpoints path"
self.timestamp_str = self.session_timestamp.strftime(
@@ -167,11 +166,6 @@ class AgentSessionABC(ABC):
):
if self._can_learn:
_LOGGER.debug("Writing transactions")
write_transaction_to_file(
transaction_list=self._transaction_list,
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._update_session_metadata_file()
self._can_evaluate = True

View File

@@ -1,18 +1,23 @@
import glob
import time
from enum import Enum
import json
from datetime import datetime
from pathlib import Path
from typing import Union, Optional
from pathlib import Path
from typing import Optional
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.a2c import A2CConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.config import training_config
from primaite.common.enums import AgentFramework, RedAgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
@@ -23,7 +28,17 @@ def _env_creator(env_config):
)
class RLlibPPO(AgentSessionABC):
def _custom_log_creator(session_path: Path):
logdir = session_path / "ray_results"
logdir.mkdir(parents=True, exist_ok=True)
def logger_creator(config):
return UnifiedLogger(config, logdir, loggers=None)
return logger_creator
class RLlibAgent(AgentSessionABC):
def __init__(
self,
@@ -31,17 +46,63 @@ class RLlibPPO(AgentSessionABC):
lay_down_config_path
):
super().__init__(training_config_path, lay_down_config_path)
self._ppo_config: PPOConfig
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = (f"Expected RLLIB agent_framework, "
f"got {self._training_config.agent_framework}")
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
self._agent_config_class = PPOConfig
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = ("Expected PPO or A2C red_agent_identifier, "
f"got {self._training_config.red_agent_identifier.value}")
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: PPOConfig
self._current_result: dict
self._setup()
self._agent.save()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"red_agent_identifier="
f"{self._training_config.red_agent_identifier}, "
f"deep_learning_framework="
f"{self._training_config.deep_learning_framework}"
)
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the ``session_metadata.json`` in the ``session_path`` directory
with the following key/value pairs:
- end_datetime: The date & time the session ended in iso format.
- total_episodes: The total number of training episodes completed.
- total_time_steps: The total number of training time steps completed.
"""
with open(self.session_path / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
_LOGGER.debug("Finished updating session metadata file")
def _setup(self):
super()._setup()
register_env("primaite", _env_creator)
self._ppo_config = PPOConfig()
self._agent_config = self._agent_config_class()
self._ppo_config.environment(
self._agent_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
@@ -52,19 +113,21 @@ class RLlibPPO(AgentSessionABC):
)
)
self._ppo_config.training(
self._agent_config.training(
train_batch_size=self._training_config.num_steps
)
self._ppo_config.framework(
framework=self._training_config.deep_learning_framework.value
self._agent_config.framework(
framework=self._training_config.deep_learning_framework
)
self._ppo_config.rollouts(
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps
)
self._agent: Algorithm = self._ppo_config.build()
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.session_path)
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
@@ -84,8 +147,8 @@ class RLlibPPO(AgentSessionABC):
):
# Temporarily override train_batch_size and horizon
if time_steps:
self._ppo_config.train_batch_size = time_steps
self._ppo_config.horizon = time_steps
self._agent_config.train_batch_size = time_steps
self._agent_config.horizon = time_steps
if not episodes:
episodes = self._training_config.num_episodes

View File

@@ -1,23 +1,48 @@
from typing import Optional
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3 import PPO, A2C
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import RedAgentIdentifier, AgentFramework
from primaite.environment.primaite_env import Primaite
from stable_baselines3.ppo import MlpPolicy as PPOMlp
_LOGGER = getLogger(__name__)
class SB3PPO(AgentSessionABC):
class SB3Agent(AgentSessionABC):
def __init__(
self,
training_config_path,
lay_down_config_path
):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = (f"Expected SB3 agent_framework, "
f"got {self._training_config.agent_framework}")
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = ("Expected PPO or A2C red_agent_identifier, "
f"got {self._training_config.red_agent_identifier.value}")
_LOGGER.error(msg)
raise ValueError(msg)
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
self._setup()
_LOGGER.debug(
f"Created {self.__class__.__name__} using: "
f"agent_framework={self._training_config.agent_framework}, "
f"red_agent_identifier="
f"{self._training_config.red_agent_identifier}"
)
def _setup(self):
super()._setup()
@@ -28,10 +53,10 @@ class SB3PPO(AgentSessionABC):
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
self._agent = PPO(
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=1,
verbose=self._training_config.output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path
)
@@ -65,6 +90,7 @@ class SB3PPO(AgentSessionABC):
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.close()
super().learn()
def evaluate(

View File

@@ -80,13 +80,13 @@ class Protocol(Enum):
class SessionType(Enum):
"The type of PrimAITE Session to be run."
"""The type of PrimAITE Session to be run."""
TRAINING = 1
EVALUATION = 2
BOTH = 3
class VerboseLevel(Enum):
class VerboseLevel(IntEnum):
"""PrimAITE Session Output verbose level."""
NO_OUTPUT = 0
INFO = 1

View File

@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 128
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 128
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -1,7 +1,3 @@
- item_type: ACTIONS
type: NODE
- item_type: STEPS
steps: 256
- item_type: PORTS
ports_list:
- port: '80'

View File

@@ -59,7 +59,7 @@ observation_space_high_value: 1000000000
# "NONE" (No Output)
# "INFO" (Info Messages)
# "ALL" (All Messages)
output_verbose_level: INFO
output_verbose_level: NONE
# Reward values
# Generic

View File

@@ -185,7 +185,6 @@ class TrainingConfig:
for field, enum_class in field_enum_map.items():
if field in config_dict:
config_dict[field] = enum_class[config_dict[field]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
@@ -203,6 +202,7 @@ class TrainingConfig:
data["red_agent_identifier"] = self.red_agent_identifier.value
data["action_type"] = self.action_type.value
data["output_verbose_level"] = self.output_verbose_level.value
data["session_type"] = self.session_type.value
return data

View File

@@ -45,6 +45,8 @@ from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
@@ -407,10 +409,19 @@ class Primaite(Env):
# Return
return self.env_obs, reward, done, self.step_info
def __close__(self):
"""Override close function."""
self.csv_file.close()
def close(self):
self.__close__()
def __close__(self):
"""
Override close function
"""
write_transaction_to_file(
self.transaction_list,
self.session_path,
self.timestamp_str
)
self.csv_file.close()
def init_acl(self):
"""Initialise the Access Control List."""
self.acl.remove_all_rules()

View File

@@ -1,229 +1,162 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """
# The main PrimAITE session runner module.
#
# TODO: This will eventually be refactored out into a proper Session class.
# TODO: The passing about of session_path and timestamp_str is temporary and
# will be cleaned up once we move to a proper Session class.
# """
# import argparse
# import json
# import time
# from datetime import datetime
# from pathlib import Path
# from typing import Final, Union
# from uuid import uuid4
#
# from stable_baselines3 import A2C, PPO
# from stable_baselines3.common.evaluation import evaluate_policy
# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
# from stable_baselines3.ppo import MlpPolicy as PPOMlp
#
# from primaite import SESSIONS_DIR, getLogger
# from primaite.config.training_config import TrainingConfig
# from primaite.environment.primaite_env import Primaite
# from primaite.transactions.transactions_to_file import \
# write_transaction_to_file
#
# _LOGGER = getLogger(__name__)
#
#
# def run_generic(env: Primaite, config_values: TrainingConfig):
# """
# Run against a generic agent.
#
# :param env: An instance of
# :class:`~primaite.environment.primaite_env.Primaite`.
# :param config_values: An instance of
# :class:`~primaite.config.training_config.TrainingConfig`.
# """
# for episode in range(0, config_values.num_episodes):
# env.reset()
# for step in range(0, config_values.num_steps):
# # Send the observation space to the agent to get an action
# # TEMP - random action for now
# # action = env.blue_agent_action(obs)
# action = env.action_space.sample()
#
# # Run the simulation step on the live environment
# obs, reward, done, info = env.step(action)
#
# # Break if done is True
# if done:
# break
#
# # Introduce a delay between steps
# time.sleep(config_values.time_delay / 1000)
#
# # Reset the environment at the end of the episode
#
# env.close()
#
#
# def run_stable_baselines3_ppo(
# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
# ):
# """
# Run against a stable_baselines3 PPO agent.
#
# :param env: An instance of
# :class:`~primaite.environment.primaite_env.Primaite`.
# :param config_values: An instance of
# :class:`~primaite.config.training_config.TrainingConfig`.
# :param session_path: The directory path the session is writing to.
# :param timestamp_str: The session timestamp in the format:
# <yyyy-mm-dd>_<hh-mm-ss>.
# """
# if config_values.load_agent:
# try:
# agent = PPO.load(
# config_values.agent_load_file,
# env,
# verbose=0,
# n_steps=config_values.num_steps,
# )
# except Exception:
# print(
# "ERROR: Could not load agent at location: "
# + config_values.agent_load_file
# )
# _LOGGER.error("Could not load agent")
# _LOGGER.error("Exception occured", exc_info=True)
# else:
# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
#
# if config_values.session_type == "TRAINING":
# # We're in a training session
# print("Starting training session...")
# _LOGGER.debug("Starting training session...")
# for episode in range(config_values.num_episodes):
# agent.learn(total_timesteps=config_values.num_steps)
# _save_agent(agent, session_path, timestamp_str)
# else:
# # Default to being in an evaluation session
# print("Starting evaluation session...")
# _LOGGER.debug("Starting evaluation session...")
# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
#
# env.close()
#
#
#
#
# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
# """
# Persist an agent.
#
# Only works for stable baselines3 agents at present.
#
# :param session_path: The directory path the session is writing to.
# :param timestamp_str: The session timestamp in the format:
# <yyyy-mm-dd>_<hh-mm-ss>.
# """
# if not isinstance(agent, OnPolicyAlgorithm):
# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
# _LOGGER.error(msg)
# else:
# filepath = session_path / f"agent_saved_{timestamp_str}"
# agent.save(filepath)
# _LOGGER.debug(f"Trained agent saved as: {filepath}")
#
#
#
#
# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
# """Run the PrimAITE Session.
#
# :param training_config_path: The training config filepath.
# :param lay_down_config_path: The lay down config filepath.
# """
# # Welcome message
# print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
# uuid = str(uuid4())
# session_timestamp: Final[datetime] = datetime.now()
# session_path = _get_session_path(session_timestamp)
# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
#
# print(f"The output directory for this session is: {session_path}")
#
# # Create a list of transactions
# # A transaction is an object holding the:
# # - episode #
# # - step #
# # - initial observation space
# # - action
# # - reward
# # - new observation space
# transaction_list = []
#
# # Create the Primaite environment
# env = Primaite(
# training_config_path=training_config_path,
# lay_down_config_path=lay_down_config_path,
# transaction_list=transaction_list,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Writing Session Metadata file...")
#
# _write_session_metadata_file(
# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env
# )
#
# config_values = env.training_config
#
# # Get the number of steps (which is stored in the child config file)
# config_values.num_steps = env.episode_steps
#
# # Run environment against an agent
# if config_values.agent_identifier == "GENERIC":
# run_generic(env=env, config_values=config_values)
# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
# run_stable_baselines3_ppo(
# env=env,
# config_values=config_values,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
# run_stable_baselines3_a2c(
# env=env,
# config_values=config_values,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Session finished")
# _LOGGER.debug("Session finished")
#
# print("Saving transaction logs...")
# write_transaction_to_file(
# transaction_list=transaction_list,
# session_path=session_path,
# timestamp_str=timestamp_str,
# )
#
# print("Updating Session Metadata file...")
# _update_session_metadata_file(session_path=session_path, env=env)
#
# print("Finished")
# _LOGGER.debug("Finished")
#
#
# if __name__ == "__main__":
# parser = argparse.ArgumentParser()
# parser.add_argument("--tc")
# parser.add_argument("--ldc")
# args = parser.parse_args()
# if not args.tc:
# _LOGGER.error(
# "Please provide a training config file using the --tc " "argument"
# )
# if not args.ldc:
# _LOGGER.error(
# "Please provide a lay down config file using the --ldc " "argument"
# )
# run(training_config_path=args.tc, lay_down_config_path=args.ldc)
#
#
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
The main PrimAITE session runner module.
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_path and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
from uuid import uuid4
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
_LOGGER = getLogger(__name__)
def run_generic(env: Primaite, config_values: TrainingConfig):
    """
    Run against a generic agent.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    """
    # time_delay is configured in milliseconds; sleep() wants seconds.
    step_delay_seconds = config_values.time_delay / 1000
    for _episode in range(config_values.num_episodes):
        # Start each episode from a fresh environment state.
        env.reset()
        for _step in range(config_values.num_steps):
            # Send the observation space to the agent to get an action
            # TEMP - random action for now
            # action = env.blue_agent_action(obs)
            sampled_action = env.action_space.sample()

            # Run the simulation step on the live environment
            obs, reward, done, info = env.step(sampled_action)

            # Break if done is True
            if done:
                break

            # Introduce a delay between steps
            time.sleep(step_delay_seconds)
    env.close()
def run_stable_baselines3_ppo(
    env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
    """
    Run against a stable_baselines3 PPO agent.

    :param env: An instance of
        :class:`~primaite.environment.primaite_env.Primaite`.
    :param config_values: An instance of
        :class:`~primaite.config.training_config.TrainingConfig`.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    :raises Exception: Re-raised if a saved agent cannot be loaded from
        ``config_values.agent_load_file``.
    """
    if config_values.load_agent:
        try:
            agent = PPO.load(
                config_values.agent_load_file,
                env,
                verbose=0,
                n_steps=config_values.num_steps,
            )
        except Exception:
            # BUG FIX: previously this branch only logged and fell through,
            # leaving `agent` unbound so the code below failed with a
            # confusing NameError. Re-raise so the failure surfaces at its
            # real cause. (f-string also avoids a TypeError when
            # agent_load_file is a Path rather than a str.)
            print(
                "ERROR: Could not load agent at location: "
                f"{config_values.agent_load_file}"
            )
            _LOGGER.error("Could not load agent")
            _LOGGER.error("Exception occurred", exc_info=True)
            raise
    else:
        agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)

    if config_values.session_type == "TRAINING":
        # We're in a training session
        print("Starting training session...")
        _LOGGER.debug("Starting training session...")
        for episode in range(config_values.num_episodes):
            agent.learn(total_timesteps=config_values.num_steps)
        _save_agent(agent, session_path, timestamp_str)
    else:
        # Default to being in an evaluation session
        print("Starting evaluation session...")
        _LOGGER.debug("Starting evaluation session...")
        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)

    env.close()
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
    """
    Persist an agent.

    Only works for stable baselines3 agents at present; anything else is
    logged as an error and left unsaved.

    :param agent: The trained agent instance to save.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format:
        <yyyy-mm-dd>_<hh-mm-ss>.
    """
    # Guard clause: bail out early on unsupported agent types.
    if not isinstance(agent, OnPolicyAlgorithm):
        msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
        _LOGGER.error(msg)
        return

    filepath = session_path / f"agent_saved_{timestamp_str}"
    agent.save(filepath)
    _LOGGER.debug(f"Trained agent saved as: {filepath}")
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
    """Run the PrimAITE Session.

    :param training_config_path: The training config filepath.
    :param lay_down_config_path: The lay down config filepath.
    """
    session = PrimaiteSession(training_config_path, lay_down_config_path)
    session.setup()
    # PrimaiteSession gates learn() and evaluate() internally on the
    # configured session_type (learn is skipped for EVALUATION sessions,
    # evaluate for TRAINING sessions). Calling both here means TRAINING,
    # EVALUATION, and BOTH session types all do the right thing; previously
    # only learn() was called, so EVALUATION sessions did nothing.
    session.learn()
    session.evaluate()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tc")
    parser.add_argument("--ldc")
    args = parser.parse_args()
    missing_arg = False
    if not args.tc:
        _LOGGER.error(
            "Please provide a training config file using the --tc " "argument"
        )
        missing_arg = True
    if not args.ldc:
        _LOGGER.error(
            "Please provide a lay down config file using the --ldc " "argument"
        )
        missing_arg = True
    # BUG FIX: previously execution fell through to run() even when a
    # required argument was missing, passing None paths and producing a
    # confusing downstream failure. Exit with a non-zero status instead.
    if missing_arg:
        raise SystemExit(1)
    run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -8,10 +8,10 @@ from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.agents.agent import AgentSessionABC
from primaite.agents.rllib import RLlibPPO
from primaite.agents.sb3 import SB3PPO
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
ActionType
ActionType, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
@@ -95,35 +95,19 @@ class PrimaiteSession:
pass
elif self._training_config.agent_framework == AgentFramework.SB3:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Stable Baselines3/Proximal Policy Optimization
self._agent_session = SB3PPO(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Stable Baselines3/Advantage Actor Critic
raise NotImplementedError
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
# Stable Baselines3 Agent
self._agent_session = SB3Agent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Ray RLlib/Proximal Policy Optimization
self._agent_session = RLlibPPO(
self._training_config_path,
self._lay_down_config_path
)
# Ray RLlib Agent
self._agent_session = RLlibAgent(
self._training_config_path,
self._lay_down_config_path
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Ray RLlib/Advantage Actor Critic
raise NotImplementedError
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
else:
# Invalid AgentFramework
pass
@@ -134,7 +118,8 @@ class PrimaiteSession:
episodes: Optional[int] = None,
**kwargs
):
self._agent_session.learn(time_steps, episodes, **kwargs)
if not self._training_config.session_type == SessionType.EVALUATION:
self._agent_session.learn(time_steps, episodes, **kwargs)
def evaluate(
self,
@@ -142,18 +127,5 @@ class PrimaiteSession:
episodes: Optional[int] = None,
**kwargs
):
self._agent_session.evaluate(time_steps, episodes, **kwargs)
@classmethod
def import_agent(
cls,
gent_path: str,
training_config_path: str,
lay_down_config_path: str
) -> PrimaiteSession:
session = PrimaiteSession(training_config_path, lay_down_config_path)
# Reset the UUID
session._uuid = ""
return session
if not self._training_config.session_type == SessionType.TRAINING:
self._agent_session.evaluate(time_steps, episodes, **kwargs)

View File

@@ -4,6 +4,8 @@
import csv
from pathlib import Path
import numpy as np
from primaite import getLogger
_LOGGER = getLogger(__name__)
@@ -54,8 +56,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
# Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action
# space as "AS_1"
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
if isinstance(template_transation.action_space, int):
action_length = template_transation.action_space
else:
action_length = template_transation.action_space.size
obs_shape = template_transation.obs_space_post.shape
obs_assets = template_transation.obs_space_post.shape[0]
if len(obs_shape) == 1: