Implement agent training with Stable Baselines3 (SB3)
This commit is contained in:
@@ -112,7 +112,7 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: RLAgent
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from primaite.game.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy"]
|
||||
|
||||
@@ -16,7 +16,7 @@ class PolicyABC(ABC):
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, name: str, **kwargs: Any) -> None:
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
@@ -25,9 +25,9 @@ class PolicyABC(ABC):
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {name}")
|
||||
cls._registry[name] = cls
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
@@ -78,6 +78,6 @@ class PolicyABC(ABC):
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config()
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
@@ -11,7 +12,7 @@ if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC):
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
@@ -39,16 +40,18 @@ class SB3Policy(PolicyABC):
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for i in range(n_episodes):
|
||||
self._agent.learn(total_timesteps=n_time_steps)
|
||||
self._save_checkpoint()
|
||||
# self._save_checkpoint()
|
||||
pass
|
||||
|
||||
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
|
||||
for episode in range(n_episodes):
|
||||
obs = self.session.env.reset()
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(n_time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, truncated, terminated, info = self.session.env.step(action)
|
||||
|
||||
def save(self) -> None:
|
||||
|
||||
@@ -33,7 +33,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteEnv(gymnasium.Env):
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
@@ -57,10 +57,10 @@ class PrimaiteEnv(gymnasium.Env):
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
|
||||
next_obs = self.agent.observation_manager.current_observation
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = ...
|
||||
truncated = False
|
||||
info = {}
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
@@ -70,19 +70,25 @@ class PrimaiteEnv(gymnasium.Env):
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
next_obs = self.agent.observation_manager.current_observation
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.action_space
|
||||
return self.agent.action_manager.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return self.agent.observation_manager.observation_space
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
@@ -122,6 +128,9 @@ class PrimaiteSession:
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
|
||||
self.rl_agents: List[ProxyAgent] = []
|
||||
"""Subset of agent list including only the reinforcement learning agents."""
|
||||
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
@@ -149,7 +158,8 @@ class PrimaiteSession:
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
# self.env:
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session."""
|
||||
@@ -423,7 +433,7 @@ class PrimaiteSession:
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "RLAgent":
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
@@ -431,6 +441,7 @@ class PrimaiteSession:
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -442,7 +453,10 @@ class PrimaiteSession:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options)
|
||||
# CREATE ENVIRONMENT
|
||||
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user