Files
PrimAITE/src/primaite/main.py

380 lines
13 KiB
Python

# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""
The main PrimAITE session runner module.
TODO: This will eventually be refactored out into a proper Session class.
TODO: The passing about of session_dir and timestamp_str is temporary and
will be cleaned up once we move to a proper Session class.
"""
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Final, Union
from uuid import uuid4
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = getLogger(__name__)
def run_generic(env: Primaite, config_values: TrainingConfig):
"""
Run against a generic agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
"""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = env.action_space.sample()
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
env.close()
def run_stable_baselines3_ppo(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 PPO agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = PPO.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(
PPOMlp,
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
for episode in range(0, config_values.num_episodes):
obs = env.reset()
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = env.step(action)
env.close()
def run_stable_baselines3_a2c(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 A2C agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = A2C.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C(
"MlpPolicy",
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
for episode in range(0, config_values.num_episodes):
obs = env.reset()
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = env.step(action)
env.close()
def _write_session_metadata_file(
session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
):
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_dir`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": uuid,
"start_datetime": session_timestamp.isoformat(),
"end_datetime": None,
"total_episodes": None,
"total_time_steps": None,
"env": {
"training_config": env.training_config.to_dict(json_serializable=True),
"lay_down_config": env.lay_down_config,
},
}
filepath = session_dir / "session_metadata.json"
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
def _update_session_metadata_file(session_dir: Path, env: Primaite):
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_dir`` directory
with the following key/value pairs:
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
"""
with open(session_dir / "session_metadata.json", "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = env.episode_count
metadata_dict["total_time_steps"] = env.total_step_count
filepath = session_dir / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
with open(filepath, "w") as file:
json.dump(metadata_dict, file)
def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str):
"""
Persist an agent.
Only works for stable baselines3 agents at present.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if not isinstance(agent, OnPolicyAlgorithm):
msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}."
_LOGGER.error(msg)
else:
filepath = session_path / f"agent_saved_{timestamp_str}"
agent.save(filepath)
_LOGGER.debug(f"Trained agent saved as: {filepath}")
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]):
"""Run the PrimAITE Session.
:param training_config_path: The training config filepath.
:param lay_down_config_path: The lay down config filepath.
"""
# Welcome message
print("Welcome to the Primary-level AI Training Environment (PrimAITE)")
uuid = str(uuid4())
session_timestamp: Final[datetime] = datetime.now()
session_dir = _get_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
print(f"The output directory for this session is: {session_dir}")
# Create a list of transactions
# A transaction is an object holding the:
# - episode #
# - step #
# - initial observation space
# - action
# - reward
# - new observation space
transaction_list = []
# Create the Primaite environment
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Writing Session Metadata file...")
_write_session_metadata_file(
session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env
)
config_values = env.training_config
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
# Run environment against an agent
if config_values.agent_identifier == "GENERIC":
run_generic(env=env, config_values=config_values)
elif config_values.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
elif config_values.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
write_transaction_to_file(
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
)
print("Updating Session Metadata file...")
_update_session_metadata_file(session_dir=session_dir, env=env)
print("Finished")
_LOGGER.debug("Finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tc")
parser.add_argument("--ldc")
args = parser.parse_args()
if not args.tc:
_LOGGER.error(
"Please provide a training config file using the --tc " "argument"
)
if not args.ldc:
_LOGGER.error(
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)