#917 - Got things working-ish

This commit is contained in:
Chris McCarthy
2023-06-20 22:29:46 +01:00
parent a2cc4233b5
commit 5a6fdf58d4
3 changed files with 16 additions and 5 deletions

View File

@@ -8,10 +8,11 @@ from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import register_env
import tensorflow as tf
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.common.enums import AgentFramework, AgentIdentifier, \
DeepLearningFramework
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
@@ -115,7 +116,7 @@ class RLlibAgent(AgentSessionABC):
train_batch_size=self._training_config.num_steps
)
self._agent_config.framework(
framework="torch"
framework="tf"
)
self._agent_config.rollouts(
@@ -127,6 +128,7 @@ class RLlibAgent(AgentSessionABC):
logger_creator=_custom_log_creator(self.session_path)
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
@@ -154,8 +156,14 @@ class RLlibAgent(AgentSessionABC):
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self._agent.stop()
if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH:
policy = self._agent.get_policy()
tf.compat.v1.summary.FileWriter(
self.session_path / "ray_results",
policy.get_session().graph
)
super().learn()
self._agent.stop()
def evaluate(
self,

View File

@@ -86,10 +86,10 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.close()
super().learn()

View File

@@ -3,6 +3,7 @@
import copy
import csv
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple, Union, Final
@@ -301,6 +302,8 @@ class Primaite(Env):
done: Indicates episode is complete if True
step_info: Additional information relating to this step
"""
# Introduce a delay between steps (time_delay is in milliseconds; time.sleep takes seconds)
time.sleep(self.training_config.time_delay / 1000)
if self.step_count == 0:
print(f"Episode: {str(self.episode_count)}")