Fix incorrect number of steps per episode
This commit is contained in:
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"WARNING": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
|
||||
@@ -2,10 +2,9 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_learn_steps: 2560
|
||||
n_eval_episodes: 5
|
||||
n_eval_steps: 128
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
@@ -33,26 +33,22 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
env=self.session.env,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
) # TODO: populate values once I figure out how to get them from the config / session
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, n_time_steps: int) -> None:
|
||||
def learn(self, n_time_steps: int) -> None:
|
||||
"""Train the agent."""
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for i in range(n_episodes):
|
||||
self._agent.learn(total_timesteps=n_time_steps)
|
||||
# self._save_checkpoint()
|
||||
pass
|
||||
self._agent.learn(total_timesteps=n_time_steps)
|
||||
|
||||
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for episode in range(n_episodes):
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(n_time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, truncated, terminated, info = self.session.env.step(action)
|
||||
reward_data = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
print(reward_data)
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
import enlighten
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel
|
||||
@@ -30,6 +32,8 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
progress_bar_manager = enlighten.get_manager()
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +64,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = False
|
||||
truncated = self.session.calculate_truncated()
|
||||
info = {}
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
@@ -108,15 +112,22 @@ class TrainingOptions(BaseModel):
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
seed: Optional[int]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int = 0
|
||||
n_eval_steps: Optional[int] = None
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
deterministic_eval: bool
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
|
||||
@@ -161,18 +172,31 @@ class PrimaiteSession:
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.training_progress_bar: Optional[enlighten.Counter] = None
|
||||
"""training steps counter"""
|
||||
|
||||
self.eval_progress_bar: Optional[enlighten.Counter] = None
|
||||
"""evaluation episodes counter"""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
self.training_progress_bar = progress_bar_manager.counter(
|
||||
total=self.training_options.n_learn_steps, desc="Training steps"
|
||||
)
|
||||
n_learn_steps = self.training_options.n_learn_steps
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_steps = self.training_options.n_eval_steps
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
deterministic_eval = True # TODO: get this value from config
|
||||
if n_learn_episodes > 0:
|
||||
self.policy.learn(n_episodes=n_learn_episodes, n_time_steps=n_learn_steps)
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(n_time_steps=n_learn_steps)
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, n_time_steps=n_eval_steps, deterministic=deterministic_eval)
|
||||
self.eval_progress_bar = progress_bar_manager.counter(total=n_eval_episodes, desc="Evaluation episodes")
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@@ -227,12 +251,29 @@ class PrimaiteSession:
|
||||
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
self.step_counter += 1
|
||||
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
|
||||
if self.training_progress_bar and self.mode == SessionMode.TRAIN:
|
||||
self.training_progress_bar.update()
|
||||
|
||||
def calculate_truncated(self) -> bool:
|
||||
"""Calculate whether the episode is truncated."""
|
||||
current_step = self.step_counter
|
||||
max_steps = self.training_options.max_steps_per_episode
|
||||
if current_step >= max_steps:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
return NotImplemented
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
if self.eval_progress_bar and self.mode == SessionMode.EVAL:
|
||||
self.eval_progress_bar.update()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
|
||||
Reference in New Issue
Block a user