diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 7d0cde60..b4b0ec56 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,10 +8,11 @@ from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env - +import tensorflow as tf from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier, \ + DeepLearningFramework from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) @@ -115,7 +116,7 @@ class RLlibAgent(AgentSessionABC): train_batch_size=self._training_config.num_steps ) self._agent_config.framework( - framework="torch" + framework="tf" ) self._agent_config.rollouts( @@ -127,6 +128,7 @@ class RLlibAgent(AgentSessionABC): logger_creator=_custom_log_creator(self.session_path) ) + def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] @@ -154,8 +156,14 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - self._agent.stop() + if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH: + policy = self._agent.get_policy() + tf.compat.v1.summary.FileWriter( + self.session_path / "ray_results", + policy.get_session().graph + ) super().learn() + self._agent.stop() def evaluate( self, diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 3748b57d..073eb2fe 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -86,10 +86,10 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() + self._env.close() super().learn() diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 502069ec..4a958fa6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,6 +3,7 @@ import copy import csv import logging +import time from datetime import datetime from pathlib import Path from typing import Dict, Tuple, Union, Final @@ -301,6 +302,8 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ + # Introduce a delay between steps + time.sleep(self.training_config.time_delay / 1000) if self.step_count == 0: print(f"Episode: {str(self.episode_count)}")