from typing import Optional

import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy as PPOMlp

from primaite.agents.agent import AgentSessionABC
from primaite.environment.primaite_env import Primaite


class SB3PPO(AgentSessionABC):
    """An agent session that trains a Stable Baselines3 PPO agent in the PrimAITE environment."""

    def __init__(self, training_config_path, lay_down_config_path):
        super().__init__(training_config_path, lay_down_config_path)
        # Keep TensorBoard logs alongside the rest of the session output.
        self._tensorboard_log_path = self.session_path / "tensorboard_logs"
        self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
        self._setup()

    def _setup(self):
        super()._setup()
        # Build the PrimAITE environment from the session's config files.
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=[],
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        # PPO with an MLP policy; n_steps is taken from the training config.
        self._agent = PPO(
            PPOMlp,
            self._env,
            verbose=1,
            n_steps=self._training_config.num_steps,
            tensorboard_log=str(self._tensorboard_log_path),
        )

    def _save_checkpoint(self):
        """Save the agent every N episodes (per the training config) and on the final episode."""
        checkpoint_n = self._training_config.checkpoint_every_n_episodes
        episode_count = self._env.episode_count
        if checkpoint_n > 0 and episode_count > 0:
            if (
                episode_count % checkpoint_n == 0
                or episode_count == self._training_config.num_episodes
            ):
                self._agent.save(self.checkpoints_path / f"sb3ppo_{episode_count}.zip")

    def _get_latest_checkpoint(self):
        """Resuming from the latest checkpoint is not implemented for this agent."""
        pass

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs
    ):
        """Train the agent, defaulting to the step and episode counts in the training config."""
        if not time_steps:
            time_steps = self._training_config.num_steps

        if not episodes:
            episodes = self._training_config.num_episodes

        for _ in range(episodes):
            # Each call runs one episode's worth of timesteps, then checkpoints.
            self._agent.learn(total_timesteps=time_steps)
            self._save_checkpoint()
        super().learn()

    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        deterministic: bool = True,
        **kwargs
    ):
        """Run the trained agent, defaulting to the step and episode counts in the training config."""
        if not time_steps:
            time_steps = self._training_config.num_steps

        if not episodes:
            episodes = self._training_config.num_episodes

        for _ in range(episodes):
            obs = self._env.reset()

            for _ in range(time_steps):
                action, _states = self._agent.predict(
                    obs,
                    deterministic=deterministic
                )
                # The environment expects a scalar discrete action, not a 0-d array.
                if isinstance(action, np.ndarray):
                    action = np.int64(action)
                obs, rewards, done, info = self._env.step(action)

    def load(self):
        """Loading a saved agent is not implemented."""
        raise NotImplementedError

    def save(self):
        """Saving the trained agent is not implemented."""
        raise NotImplementedError

    def export(self):
        """Exporting the trained agent is not implemented."""
        raise NotImplementedError
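

# Usage sketch (illustrative only, not part of the module above): the two config
# paths are placeholders and must point at real PrimAITE training and lay-down
# YAML files for this to run.
if __name__ == "__main__":
    session = SB3PPO(
        training_config_path="path/to/training_config.yaml",
        lay_down_config_path="path/to/lay_down_config.yaml",
    )
    session.learn()     # train for the configured number of episodes and steps
    session.evaluate()  # roll the trained policy out deterministically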