Fix incorrect number of steps per episode

This commit is contained in:
Marek Wolan
2023-11-15 14:49:44 +00:00
parent c8f2f193bd
commit 6182b53bfd
4 changed files with 69 additions and 32 deletions

View File

@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}

View File

@@ -2,10 +2,9 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_learn_steps: 2560
n_eval_episodes: 5
n_eval_steps: 128
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:

View File

@@ -1,9 +1,9 @@
"""Stable baselines 3 policy."""
from typing import Literal, Optional, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
@@ -33,26 +33,22 @@ class SB3Policy(PolicyABC, identifier="SB3"):
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
) # TODO: populate values once I figure out how to get them from the config / session
)
def learn(self, n_episodes: int, n_time_steps: int) -> None:
def learn(self, n_time_steps: int) -> None:
"""Train the agent."""
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for i in range(n_episodes):
self._agent.learn(total_timesteps=n_time_steps)
# self._save_checkpoint()
pass
self._agent.learn(total_timesteps=n_time_steps)
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for episode in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(n_time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, truncated, terminated, info = self.session.env.step(action)
reward_data = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
print(reward_data)
def save(self) -> None:
"""Save the agent."""

View File

@@ -1,7 +1,9 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
import enlighten
import gymnasium
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
@@ -30,6 +32,8 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
progress_bar_manager = enlighten.get_manager()
_LOGGER = getLogger(__name__)
@@ -60,7 +64,7 @@ class PrimaiteGymEnv(gymnasium.Env):
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = False
truncated = self.session.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
@@ -108,15 +112,22 @@ class TrainingOptions(BaseModel):
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
seed: Optional[int]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int = 0
n_eval_steps: Optional[int] = None
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
deterministic_eval: bool
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
@@ -161,18 +172,31 @@ class PrimaiteSession:
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.training_progress_bar: Optional[enlighten.Counter] = None
"""training steps counter"""
self.eval_progress_bar: Optional[enlighten.Counter] = None
"""evaluation episodes counter"""
self.mode: SessionMode = SessionMode.MANUAL
def start_session(self) -> None:
"""Commence the training session."""
self.mode = SessionMode.TRAIN
self.training_progress_bar = progress_bar_manager.counter(
total=self.training_options.n_learn_steps, desc="Training steps"
)
n_learn_steps = self.training_options.n_learn_steps
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_steps = self.training_options.n_eval_steps
n_eval_episodes = self.training_options.n_eval_episodes
deterministic_eval = True # TODO: get this value from config
if n_learn_episodes > 0:
self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps)
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(n_time_steps=n_learn_steps)
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval)
self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes")
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def step(self):
"""
@@ -227,12 +251,29 @@ class PrimaiteSession:
def advance_timestep(self) -> None:
"""Advance timestep."""
self.simulation.apply_timestep(self.step_counter)
self.step_counter += 1
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
self.simulation.apply_timestep(self.step_counter)
if self.training_progress_bar and self.mode == SessionMode.TRAIN:
self.training_progress_bar.update()
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
return NotImplemented
self.episode_counter += 1
self.step_counter = 0
        _LOGGER.debug(f"Resetting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
if self.eval_progress_bar and self.mode == SessionMode.EVAL:
self.eval_progress_bar.update()
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""