Implement agent training with sb3

This commit is contained in:
Marek Wolan
2023-11-15 12:52:18 +00:00
parent e6ead6e532
commit c8f2f193bd
5 changed files with 39 additions and 19 deletions

View File

@@ -112,7 +112,7 @@ game_config:
- ref: defender
team: BLUE
type: RLAgent
type: ProxyAgent
observation_space:
type: UC2BlueObservation

View File

@@ -0,0 +1,3 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -16,7 +16,7 @@ class PolicyABC(ABC):
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, name: str, **kwargs: Any) -> None:
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
@@ -25,9 +25,9 @@ class PolicyABC(ABC):
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if name in cls._registry:
raise ValueError(f"Duplicate policy name {name}")
cls._registry[name] = cls
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
@@ -78,6 +78,6 @@ class PolicyABC(ABC):
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config()
return PolicyType.from_config(config=config, session=session)
# Checkpoint-saving logic will be handled here; it will invoke the 'save' method, which is implemented by the subclass.

View File

@@ -1,6 +1,7 @@
"""Stable baselines 3 policy."""
from typing import Literal, Optional, TYPE_CHECKING, Union
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
@@ -11,7 +12,7 @@ if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC):
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
@@ -39,16 +40,18 @@ class SB3Policy(PolicyABC):
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for i in range(n_episodes):
self._agent.learn(total_timesteps=n_time_steps)
self._save_checkpoint()
# self._save_checkpoint()
pass
def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None:
"""Evaluate the agent."""
# TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB
for episode in range(n_episodes):
obs = self.session.env.reset()
obs, info = self.session.env.reset()
for step in range(n_time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, truncated, terminated, info = self.session.env.step(action)
def save(self) -> None:

View File

@@ -33,7 +33,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteEnv(gymnasium.Env):
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
@@ -57,10 +57,10 @@ class PrimaiteEnv(gymnasium.Env):
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self.agent.observation_manager.current_observation
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = ...
truncated = False
info = {}
return next_obs, reward, terminated, truncated, info
@@ -70,19 +70,25 @@ class PrimaiteEnv(gymnasium.Env):
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self.agent.observation_manager.current_observation
next_obs = self._get_obs()
info = {}
return next_obs, info
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.action_space
return self.agent.action_manager.space
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return self.agent.observation_manager.observation_space
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
@@ -122,6 +128,9 @@ class PrimaiteSession:
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.step_counter: int = 0
"""Current timestep within the episode."""
@@ -149,7 +158,8 @@ class PrimaiteSession:
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
# self.env:
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteGymEnv."""
def start_session(self) -> None:
"""Commence the training session."""
@@ -423,7 +433,7 @@ class PrimaiteSession:
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "RLAgent":
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
@@ -431,6 +441,7 @@ class PrimaiteSession:
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
@@ -442,7 +453,10 @@ class PrimaiteSession:
else:
print("agent type not found")
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options)
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
return sess