# Source file: PrimAITE/src/primaite/agents/sb3.py

from typing import Optional
import numpy as np
from stable_baselines3 import PPO
from primaite.agents.agent import AgentSessionABC
from primaite.environment.primaite_env import Primaite
from stable_baselines3.ppo import MlpPolicy as PPOMlp
class SB3PPO(AgentSessionABC):
    """Agent session that drives a Stable Baselines3 PPO agent on Primaite.

    The session builds a ``Primaite`` environment and a ``PPO`` agent using
    the ``MlpPolicy`` policy, and exposes ``learn`` / ``evaluate`` entry
    points. ``load``, ``save`` and ``export`` are not implemented yet.
    """

    def __init__(
        self,
        training_config_path,
        lay_down_config_path
    ):
        """Initialise the session and build the environment and agent.

        :param training_config_path: Path of the training config, forwarded
            to the base class initialiser.
        :param lay_down_config_path: Path of the lay down config, forwarded
            to the base class initialiser.
        """
        super().__init__(training_config_path, lay_down_config_path)
        # NOTE(review): ``session_path`` is not set in this class; presumably
        # the base class initialiser provides it -- confirm.
        self._tensorboard_log_path = self.session_path / "tensorboard_logs"
        self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
        self._setup()

    def _setup(self) -> None:
        """Create the Primaite environment and the PPO agent."""
        super()._setup()
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=[],
            session_path=self.session_path,
            timestamp_str=self.timestamp_str
        )
        self._agent = PPO(
            PPOMlp,
            self._env,
            verbose=1,
            n_steps=self._training_config.num_steps,
            # SB3 documents ``tensorboard_log`` as a string path, so hand it
            # a str rather than the Path object.
            tensorboard_log=str(self._tensorboard_log_path)
        )

    def _save_checkpoint(self) -> None:
        """Save the agent if the current episode count warrants a checkpoint.

        A checkpoint is written to
        ``checkpoints_path / "sb3ppo_<episode_count>.zip"`` when
        ``checkpoint_every_n_episodes`` is positive, at least one episode
        has been counted by the environment, and the episode count is
        either a multiple of ``checkpoint_every_n_episodes`` or equal to
        the configured ``num_episodes``.
        """
        checkpoint_n = self._training_config.checkpoint_every_n_episodes
        episode_count = self._env.episode_count
        if not (checkpoint_n > 0 and episode_count > 0):
            # Checkpointing disabled, or no episode counted yet.
            return

        on_interval = episode_count % checkpoint_n == 0
        is_final = episode_count == self._training_config.num_episodes
        if on_interval or is_final:
            self._agent.save(
                self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
            )

    def _get_latest_checkpoint(self) -> None:
        """Placeholder; currently does nothing and returns ``None``."""
        pass

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs
    ) -> None:
        """Train the agent.

        :param time_steps: Value passed as ``total_timesteps`` to each
            ``PPO.learn`` call. A falsy value (``None`` or ``0``) falls back
            to ``num_steps`` from the training config.
        :param episodes: Number of ``PPO.learn`` calls to make. A falsy
            value falls back to ``num_episodes`` from the training config.
        :param kwargs: Accepted but not used by this implementation.
        """
        if not time_steps:
            time_steps = self._training_config.num_steps
        if not episodes:
            episodes = self._training_config.num_episodes

        for _ in range(episodes):
            self._agent.learn(total_timesteps=time_steps)
            self._save_checkpoint()

        super().learn()

    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        deterministic: bool = True,
        **kwargs
    ) -> None:
        """Run the agent in the environment without training it.

        :param time_steps: Maximum number of steps per episode. A falsy
            value (``None`` or ``0``) falls back to ``num_steps`` from the
            training config.
        :param episodes: Number of episodes to run. A falsy value falls
            back to ``num_episodes`` from the training config.
        :param deterministic: Forwarded to ``PPO.predict``.
        :param kwargs: Accepted but not used by this implementation.
        """
        if not time_steps:
            time_steps = self._training_config.num_steps
        if not episodes:
            episodes = self._training_config.num_episodes

        for _ in range(episodes):
            obs = self._env.reset()
            for _ in range(time_steps):
                action, _ = self._agent.predict(
                    obs,
                    deterministic=deterministic
                )
                if isinstance(action, np.ndarray):
                    # presumably the environment expects an integer action
                    # rather than an ndarray -- confirm against Primaite.
                    action = np.int64(action)
                obs, _, done, _ = self._env.step(action)
                if done:
                    # Per the Gym API, stepping an environment after it has
                    # reported ``done`` is undefined until ``reset()`` is
                    # called, so end this episode here.
                    break

    def load(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def save(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def export(self):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError