diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 0c39333c..17e5f5a5 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -112,7 +112,7 @@ game_config: - ref: defender team: BLUE - type: RLAgent + type: ProxyAgent observation_space: type: UC2BlueObservation diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py index e69de29b..29196112 100644 --- a/src/primaite/game/policy/__init__.py +++ b/src/primaite/game/policy/__init__.py @@ -0,0 +1,3 @@ +from primaite.game.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index 5669a4ff..4c8dc447 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -16,7 +16,7 @@ class PolicyABC(ABC): Automatically populated when PolicyABC subclasses are defined. Used for defining from_config. """ - def __init_subclass__(cls, name: str, **kwargs: Any) -> None: + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: """ Register a policy subclass. @@ -25,9 +25,9 @@ class PolicyABC(ABC): :raises ValueError: When attempting to create a policy with a duplicate name. """ super().__init_subclass__(**kwargs) - if name in cls._registry: - raise ValueError(f"Duplicate policy name {name}") - cls._registry[name] = cls + if identifier in cls._registry: + raise ValueError(f"Duplicate policy name {identifier}") + cls._registry[identifier] = cls return @abstractmethod @@ -78,6 +78,6 @@ class PolicyABC(ABC): # I should really define a config schema class using pydantic. 
PolicyType = cls._registry[config.rl_framework] - return PolicyType.from_config() + return PolicyType.from_config(config=config, session=session) # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 73df1b98..391b3115 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,6 +1,7 @@ """Stable baselines 3 policy.""" from typing import Literal, Optional, TYPE_CHECKING, Union +import numpy as np from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP from stable_baselines3.ppo import MlpPolicy as PPO_MLP @@ -11,7 +12,7 @@ if TYPE_CHECKING: from primaite.game.session import PrimaiteSession, TrainingOptions -class SB3Policy(PolicyABC): +class SB3Policy(PolicyABC, identifier="SB3"): """Single agent RL policy using stable baselines 3.""" def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None): @@ -39,16 +40,18 @@ class SB3Policy(PolicyABC): # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB for i in range(n_episodes): self._agent.learn(total_timesteps=n_time_steps) - self._save_checkpoint() + # self._save_checkpoint() pass def eval(self, n_episodes: int, n_time_steps: int, deterministic: bool) -> None: """Evaluate the agent.""" # TODO: consider moving this loop to the session, only if this makes sense for RAY RLLIB for episode in range(n_episodes): - obs = self.session.env.reset() + obs, info = self.session.env.reset() for step in range(n_time_steps): action, _states = self._agent.predict(obs, deterministic=deterministic) + if isinstance(action, np.ndarray): + action = np.int64(action) obs, rewards, terminated, truncated, info = self.session.env.step(action) def save(self) -> None: diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 
5556dd87..8017d0d4 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -33,7 +33,7 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) -class PrimaiteEnv(gymnasium.Env): +class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -57,10 +57,10 @@ class PrimaiteEnv(gymnasium.Env): state = self.session.get_sim_state() self.session.update_agents(state) - next_obs = self.agent.observation_manager.current_observation + next_obs = self._get_obs() reward = self.agent.reward_function.current_reward terminated = False - truncated = ... + truncated = False info = {} return next_obs, reward, terminated, truncated, info @@ -70,19 +70,25 @@ class PrimaiteEnv(gymnasium.Env): self.session.reset() state = self.session.get_sim_state() self.session.update_agents(state) - next_obs = self.agent.observation_manager.current_observation + next_obs = self._get_obs() info = {} return next_obs, info @property def action_space(self) -> gymnasium.Space: """Return the action space of the environment.""" - return self.agent.action_manager.action_space + return self.agent.action_manager.space @property def observation_space(self) -> gymnasium.Space: """Return the observation space of the environment.""" - return self.agent.observation_manager.observation_space + return gymnasium.spaces.flatten_space(self.agent.observation_manager.space) + + def _get_obs(self) -> ObsType: + """Return the current observation.""" + unflat_space = self.agent.observation_manager.space + unflat_obs = self.agent.observation_manager.current_observation + return gymnasium.spaces.flatten(unflat_space, unflat_obs) class PrimaiteSessionOptions(BaseModel): @@ -122,6 +128,9 @@ class PrimaiteSession: self.agents: List[AbstractAgent] = [] """List of agents.""" + self.rl_agents: List[ProxyAgent] = [] + """Subset of agent list including only the reinforcement learning agents.""" + self.step_counter: 
int = 0 """Current timestep within the episode.""" @@ -149,7 +158,8 @@ class PrimaiteSession: self.ref_map_links: Dict[str, Link] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" - # self.env: + self.env: PrimaiteGymEnv + """The environment that the agent can consume. Could be PrimaiteEnv.""" def start_session(self) -> None: """Commence the training session.""" @@ -423,7 +433,7 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) - elif agent_type == "RLAgent": + elif agent_type == "ProxyAgent": new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -431,6 +441,7 @@ class PrimaiteSession: reward_function=rew_function, ) sess.agents.append(new_agent) + sess.rl_agents.append(new_agent) elif agent_type == "RedDatabaseCorruptingAgent": new_agent = RandomAgent( agent_name=agent_cfg["ref"], @@ -442,7 +453,10 @@ class PrimaiteSession: else: print("agent type not found") - # CREATE POLICY - sess.policy = PolicyABC.from_config(sess.training_options) + # CREATE ENVIRONMENT + sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents) + + # CREATE POLICY + sess.policy = PolicyABC.from_config(sess.training_options, session=sess) return sess