#917 - Got things working'ish
This commit is contained in:
@@ -8,10 +8,11 @@ from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
import tensorflow as tf
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier, \
|
||||
DeepLearningFramework
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -115,7 +116,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._agent_config.framework(
|
||||
framework="torch"
|
||||
framework="tf"
|
||||
)
|
||||
|
||||
self._agent_config.rollouts(
|
||||
@@ -127,6 +128,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
logger_creator=_custom_log_creator(self.session_path)
|
||||
)
|
||||
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
@@ -154,8 +156,14 @@ class RLlibAgent(AgentSessionABC):
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self._agent.stop()
|
||||
if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH:
|
||||
policy = self._agent.get_policy()
|
||||
tf.compat.v1.summary.FileWriter(
|
||||
self.session_path / "ray_results",
|
||||
policy.get_session().graph
|
||||
)
|
||||
super().learn()
|
||||
self._agent.stop()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
|
||||
@@ -86,10 +86,10 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import copy
|
||||
import csv
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union, Final
|
||||
@@ -301,6 +302,8 @@ class Primaite(Env):
|
||||
done: Indicates episode is complete if True
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self.training_config.time_delay / 1000)
|
||||
if self.step_count == 0:
|
||||
print(f"Episode: {str(self.episode_count)}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user