diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py
new file mode 100644
index 00000000..8d5a9a08
--- /dev/null
+++ b/src/primaite/game/policy/policy.py
@@ -0,0 +1,46 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from primaite.game.session import PrimaiteSession
+
+
+class PolicyABC(ABC):
+    """Base class for reinforcement learning agents."""
+
+    @abstractmethod
+    def __init__(self, session: "PrimaiteSession") -> None:
+        """
+        Initialize a reinforcement learning agent.
+
+        :param session: The session this policy belongs to; stored as ``self.session``.
+        """
+        self.session: "PrimaiteSession" = session
+
+    @abstractmethod
+    def learn(self) -> None:
+        """Train the agent."""
+
+    @abstractmethod
+    def eval(self) -> None:
+        """Evaluate the agent."""
+
+    @abstractmethod
+    def save(self) -> None:
+        """Save the agent."""
+
+    @abstractmethod
+    def load(self) -> None:
+        """Load agent from a file."""
+
+    def close(self) -> None:
+        """Close the agent. The base implementation does nothing."""
+
+    @classmethod
+    @abstractmethod
+    def from_config(cls) -> "PolicyABC":
+        """Create an agent from a config file."""
+
+    def _save_checkpoint(self) -> None:
+        """Save a checkpoint by invoking the ``save`` method implemented by the subclass."""
+        self.save()
diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py
new file mode 100644
index 00000000..9c6b49ae
--- /dev/null
+++ b/src/primaite/game/policy/sb3.py
@@ -0,0 +1,93 @@
+from typing import Literal, Optional, TYPE_CHECKING, Union
+
+from stable_baselines3 import A2C, PPO
+from stable_baselines3.a2c import MlpPolicy as A2C_MLP
+from stable_baselines3.ppo import MlpPolicy as PPO_MLP
+
+from primaite.game.policy.policy import PolicyABC
+
+if TYPE_CHECKING:
+    from primaite.game.session import PrimaiteSession
+
+
+class SB3Policy(PolicyABC):
+    """Single agent RL policy using stable baselines 3."""
+
+    # TODO: populate values once I figure out how to get them from the config / session
+    _SAVE_PATH = "temp/path/to/save.pth"
+
+    def __init__(
+        self,
+        session: "PrimaiteSession",
+        algorithm: Literal["PPO", "A2C"],
+        n_steps: Optional[int] = None,
+        seed: Optional[int] = None,
+    ) -> None:
+        """
+        Initialize a stable baselines 3 policy.
+
+        :param session: The session this policy belongs to; its ``env`` is handed to the agent.
+        :param algorithm: Name of the stable baselines 3 algorithm to use, ``"PPO"`` or ``"A2C"``.
+        :param n_steps: Rollout length for the agent. If None, the algorithm's own default is used.
+        :param seed: Random seed for the agent. If None, the agent is left unseeded.
+        :raises ValueError: If ``algorithm`` is neither ``"PPO"`` nor ``"A2C"``.
+        """
+        super().__init__(session=session)
+
+        self._agent_class: type[Union[PPO, A2C]]
+        if algorithm == "PPO":
+            self._agent_class = PPO
+            policy = PPO_MLP
+        elif algorithm == "A2C":
+            self._agent_class = A2C
+            policy = A2C_MLP
+        else:
+            raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
+
+        # TODO: populate values once I figure out how to get them from the config / session
+        # n_steps is only forwarded when given because PPO and A2C use different defaults for it.
+        agent_kwargs = {"seed": seed}
+        if n_steps is not None:
+            agent_kwargs["n_steps"] = n_steps
+        self._agent = self._agent_class(policy=policy, env=self.session.env, **agent_kwargs)
+
+    def learn(self) -> None:
+        """Train the agent, saving a checkpoint after each round of training."""
+        time_steps = 9999  # TODO: populate values once I figure out how to get them from the config / session
+        episodes = 10  # TODO: populate values once I figure out how to get them from the config / session
+        for _ in range(episodes):
+            self._agent.learn(total_timesteps=time_steps)
+            self._save_checkpoint()
+
+    def eval(self) -> None:
+        """Evaluate the agent by running it in the session environment without training."""
+        time_steps = 9999  # TODO: populate values once I figure out how to get them from the config / session
+        num_episodes = 10  # TODO: populate values once I figure out how to get them from the config / session
+        deterministic = True  # TODO: populate values once I figure out how to get them from the config / session
+
+        env = self.session.env
+        for _ in range(num_episodes):
+            # The env's step returns five values, so its reset is expected to return (obs, info).
+            obs, _info = env.reset()
+            for _ in range(time_steps):
+                action, _states = self._agent.predict(obs, deterministic=deterministic)
+                obs, _reward, terminated, truncated, _info = env.step(action)
+                if terminated or truncated:
+                    break
+
+    def save(self) -> None:
+        """Save the agent."""
+        self._agent.save(self._SAVE_PATH)
+
+    def load(self) -> None:
+        """Load agent from a checkpoint."""
+        # ``load`` builds and returns a new model, so the result must replace the current agent.
+        self._agent = self._agent_class.load(self._SAVE_PATH, env=self.session.env)
+
+    def close(self) -> None:
+        """Close the agent."""
+
+    @classmethod
+    def from_config(cls) -> "SB3Policy":
+        """Create an agent from config file."""
+        # TODO: not implemented yet; currently returns None.