From 6182b53bfd6858d8c33d70ad8adfa1f8ca2dbabb Mon Sep 17 00:00:00 2001
From: Marek Wolan
Date: Wed, 15 Nov 2023 14:49:44 +0000
Subject: [PATCH] Fix incorrect number of steps per episode

The training config specified a separate episode count and per-episode step
count, but episodes were never actually ended at the configured length.
Replace n_learn_episodes and n_eval_steps with a single flat n_learn_steps
budget plus a max_steps_per_episode limit, signal episode truncation through
the Gymnasium step API, delegate evaluation to Stable Baselines 3's
evaluate_policy, and report training and evaluation progress with enlighten
progress bars.

---
 src/primaite/__init__.py                      |  1 +
 .../config/_package_data/example_config.yaml  |  5 +-
 src/primaite/game/policy/sb3.py               | 30 ++++-----
 src/primaite/game/session.py                  | 65 +++++++++++++++----
 4 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py
index 30fc9ab9..789517f7 100644
--- a/src/primaite/__init__.py
+++ b/src/primaite/__init__.py
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
         "DEBUG": logging.DEBUG,
         "INFO": logging.INFO,
         "WARN": logging.WARN,
+        "WARNING": logging.WARN,
         "ERROR": logging.ERROR,
         "CRITICAL": logging.CRITICAL,
     }
diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml
index 17e5f5a5..dca9620f 100644
--- a/src/primaite/config/_package_data/example_config.yaml
+++ b/src/primaite/config/_package_data/example_config.yaml
@@ -2,10 +2,9 @@ training_config:
   rl_framework: SB3
   rl_algorithm: PPO
   seed: 333
-  n_learn_episodes: 20
-  n_learn_steps: 128
+  n_learn_steps: 2560
   n_eval_episodes: 5
-  n_eval_steps: 128
+  max_steps_per_episode: 128
   deterministic_eval: false
   n_agents: 1
   agent_references:
diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py
index 391b3115..ff710944 100644
--- a/src/primaite/game/policy/sb3.py
+++ b/src/primaite/game/policy/sb3.py
@@ -1,9 +1,9 @@
 """Stable baselines 3 policy."""
 from typing import Literal, Optional, TYPE_CHECKING, Union

-import numpy as np
 from stable_baselines3 import A2C, PPO
 from stable_baselines3.a2c import MlpPolicy as A2C_MLP
+from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.ppo import MlpPolicy as PPO_MLP

 from primaite.game.policy.policy import PolicyABC
@@ -33,26 +33,22 @@ class SB3Policy(PolicyABC, identifier="SB3"):
             env=self.session.env,
             n_steps=128,  # this is not the number of steps in an episode, but the number of steps in a batch
             seed=seed,
-        )  # TODO: populate values once I figure out how to get them from the config / session
+        )

-    def learn(self, n_episodes: int, n_time_steps: int) -> None:
+    def learn(self, n_time_steps: int) -> None:
         """Train the agent."""
-        # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
-        for i in range(n_episodes):
-            self._agent.learn(total_timesteps=n_time_steps)
-            # self._save_checkpoint()
-            pass
+        self._agent.learn(total_timesteps=n_time_steps)

-    def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
+    def eval(self, n_episodes: int, deterministic: bool) -> None:
         """Evaluate the agent."""
-        # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
-        for episode in range(n_episodes):
-            obs, info = self.session.env.reset()
-            for step in range(n_time_steps):
-                action, _states = self._agent.predict(obs, deterministic=deterministic)
-                if isinstance(action, np.ndarray):
-                    action = np.int64(action)
-                obs, rewards, truncated, terminated, info = self.session.env.step(action)
+        reward_data = evaluate_policy(
+            self._agent,
+            self.session.env,
+            n_eval_episodes=n_episodes,
+            deterministic=deterministic,
+            return_episode_rewards=True,
+        )
+        print(reward_data)

     def save(self) -> None:
         """Save the agent."""
diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py
index 8017d0d4..e85328ef 100644
--- a/src/primaite/game/session.py
+++ b/src/primaite/game/session.py
@@ -1,7 +1,9 @@
 """PrimAITE session - the main entry point to training agents on PrimAITE."""
+from enum import Enum
 from ipaddress import IPv4Address
 from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple

+import enlighten
 import gymnasium
 from gymnasium.core import ActType, ObsType
 from pydantic import BaseModel
@@ -30,6 +32,8 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
 from primaite.simulator.system.services.service import Service
 from primaite.simulator.system.services.web_server.web_server import WebServer

+progress_bar_manager = enlighten.get_manager()
+
 _LOGGER = getLogger(__name__)


@@ -60,7 +64,7 @@ class PrimaiteGymEnv(gymnasium.Env):
         next_obs = self._get_obs()
         reward = self.agent.reward_function.current_reward
         terminated = False
-        truncated = False
+        truncated = self.session.calculate_truncated()
         info = {}
         return next_obs, reward, terminated, truncated, info

@@ -108,15 +112,22 @@ class TrainingOptions(BaseModel):
     rl_framework: Literal["SB3", "RLLIB"]
     rl_algorithm: Literal["PPO", "A2C"]
     seed: Optional[int]
-    n_learn_episodes: int
     n_learn_steps: int
-    n_eval_episodes: int = 0
-    n_eval_steps: Optional[int] = None
+    n_eval_episodes: Optional[int] = None
+    max_steps_per_episode: int
     deterministic_eval: bool
     n_agents: int
     agent_references: List[str]


+class SessionMode(Enum):
+    """Helper to keep track of the current session mode."""
+
+    TRAIN = "train"
+    EVAL = "eval"
+    MANUAL = "manual"
+
+
 class PrimaiteSession:
     """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""

@@ -161,18 +172,31 @@
         self.env: PrimaiteGymEnv
         """The environment that the agent can consume. Could be PrimaiteEnv."""

+        self.training_progress_bar: Optional[enlighten.Counter] = None
+        """Training steps counter."""
+
+        self.eval_progress_bar: Optional[enlighten.Counter] = None
+        """Evaluation episodes counter."""
+
+        self.mode: SessionMode = SessionMode.MANUAL
+
     def start_session(self) -> None:
         """Commence the training session."""
+        self.mode = SessionMode.TRAIN
+        self.training_progress_bar = progress_bar_manager.counter(
+            total=self.training_options.n_learn_steps, desc="Training steps"
+        )
         n_learn_steps = self.training_options.n_learn_steps
-        n_learn_episodes = self.training_options.n_learn_episodes
-        n_eval_steps = self.training_options.n_eval_steps
         n_eval_episodes = self.training_options.n_eval_episodes
-        deterministic_eval = True  # TODO: get this value from config
-        if n_learn_episodes > 0:
-            self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps)
+        deterministic_eval = self.training_options.deterministic_eval
+        self.policy.learn(n_time_steps=n_learn_steps)
+        self.mode = SessionMode.EVAL
         if n_eval_episodes > 0:
-            self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval)
+            self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes")
+            self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
+
+        self.mode = SessionMode.MANUAL

     def step(self):
         """
@@ -227,12 +251,29 @@
     def advance_timestep(self) -> None:
         """Advance timestep."""
-        self.simulation.apply_timestep(self.step_counter)
         self.step_counter += 1
+        _LOGGER.debug(f"Advancing timestep to {self.step_counter}")
+        self.simulation.apply_timestep(self.step_counter)
+
+        if self.training_progress_bar and self.mode == SessionMode.TRAIN:
+            self.training_progress_bar.update()
+
+    def calculate_truncated(self) -> bool:
+        """Calculate whether the episode is truncated."""
+        current_step = self.step_counter
+        max_steps = self.training_options.max_steps_per_episode
+        if current_step >= max_steps:
+            return True
+        return False

     def reset(self) -> None:
         """Reset the session, this will reset the simulation."""
-        return NotImplemented
+        self.episode_counter += 1
+        self.step_counter = 0
+        _LOGGER.debug(f"Resetting PrimAITE session, episode = {self.episode_counter}")
+        self.simulation.reset_component_for_episode(self.episode_counter)
+        if self.eval_progress_bar and self.mode == SessionMode.EVAL:
+            self.eval_progress_bar.update()

     def close(self) -> None:
         """Close the session, this will stop the env and close the simulation."""
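
Note on the numbers: the new flat budget preserves the total amount of
training experience from the old example config, because SB3's
learn(total_timesteps=...) counts environment steps across episode
boundaries, while calculate_truncated() now ends each episode after
max_steps_per_episode steps. A minimal sketch of the accounting, using the
values from the example_config.yaml change above (illustrative only, not
part of the patch):

    # Old config: 20 learn episodes, each run for 128 steps by the policy loop.
    n_learn_episodes_old = 20    # removed by this patch
    n_learn_steps_old = 128      # was interpreted as steps per episode

    # New config: one flat step budget; episodes truncate at 128 steps.
    max_steps_per_episode = 128
    n_learn_steps_new = 2560

    # Totals match: 20 episodes * 128 steps == 2560 steps.
    assert n_learn_episodes_old * n_learn_steps_old == n_learn_steps_new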