diff --git a/docs/source/config.rst b/docs/source/config.rst index 81468f17..1bea0671 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -20,13 +20,6 @@ The environment config file consists of the following attributes: **Generic Config Values** -* **agent_identifier** [enum] - - This identifies the agent to use for the session. Select from one of the following: - - * GENERIC - Where a user developed agent is to be used - * STABLE_BASELINES3_PPO - Use a SB3 PPO agent - * STABLE_BASELINES3_A2C - use a SB3 A2C agent * **agent_framework** [enum] diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index c500128d..c9067210 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -1,14 +1,20 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Final, Dict, Any +from primaite import getLogger +from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite +_LOGGER = getLogger(__name__) + class AgentABC(ABC): @abstractmethod def __init__(self, env: Primaite): self._env: Primaite = env + self._training_config: Final[TrainingConfig] = self._env.training_config + self._lay_down_config: Dict[str, Any] = self._env.lay_down_config self._agent = None @abstractmethod @@ -33,4 +39,35 @@ class AgentABC(ABC): @abstractmethod def export(self): - pass \ No newline at end of file + pass + + +class DeterministicAgentABC(AgentABC): + @abstractmethod + def __init__(self, env: Primaite): + self._env: Primaite = env + self._agent = None + + @abstractmethod + def _setup(self): + pass + + def learn(self, time_steps: Optional[int], episodes: Optional[int]): + _LOGGER.warning("Deterministic agents cannot learn") + return + + @abstractmethod + def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + @abstractmethod + def load(self): + pass + + @abstractmethod + def save(self): + pass + + @abstractmethod 
+ def export(self): + pass diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index cb12210c..7d0fba3b 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,28 +1,35 @@ -# from typing import Optional -# -# from primaite.agents.agent_abc import AgentABC -# from primaite.environment.primaite_env import Primaite -# -# -# class SB3PPO(AgentABC): -# def __init__(self, env: Primaite): -# super().__init__(env) -# -# def _setup(self): -# if self._env.training_config -# pass -# -# def learn(self, time_steps: Optional[int], episodes: Optional[int]): -# pass -# -# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): -# pass -# -# def load(self): -# pass -# -# def save(self): -# pass -# -# def export(self): -# pass \ No newline at end of file +from typing import Optional + +from stable_baselines3 import PPO + +from primaite.agents.agent_abc import AgentABC +from primaite.environment.primaite_env import Primaite +from stable_baselines3.ppo import MlpPolicy as PPOMlp + +class SB3PPO(AgentABC): + def __init__(self, env: Primaite): + super().__init__(env) + + def _setup(self): + self._agent = PPO( + PPOMlp, + self._env, + verbose=0, + n_steps=self._training_config.num_steps + ) + + + def learn(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + def load(self): + pass + + def save(self): + pass + + def export(self): + pass \ No newline at end of file diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 3957e822..0efc0acf 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -7,6 +7,8 @@ from typing import Final, Optional, Union from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR +from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ + ActionType from primaite.config.training_config import TrainingConfig from 
primaite.environment.primaite_env import Primaite @@ -61,7 +63,7 @@ class PrimaiteSession: self._env = None - self._training_config = None + self._training_config: TrainingConfig self._can_learn: bool = False _LOGGER.debug("") @@ -157,22 +159,70 @@ class PrimaiteSession: ): if self._can_learn: # Run environment against an agent - if self._training_config.agent_identifier == "GENERIC": - run_generic(env=env, config_values=config_values) - elif self._training_config == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - elif self._training_config == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) + if self._training_config.agent_framework == AgentFramework.NONE: + if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: + # Stochastic Random Agent + run_generic(env=env, config_values=config_values) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + pass + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + pass + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + pass + + else: + # Invalid RedAgentIdentifier ActionType combo + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.SB3: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Stable Baselines3/Proximal Policy Optimization + run_stable_baselines3_ppo( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + elif 
self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Stable Baselines3/Advantage Actor Critic + run_stable_baselines3_a2c( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.RLLIB: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Ray RLlib/Proximal Policy Optimization + pass + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Ray RLlib/Advantage Actor Critic + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + else: + # Invalid AgentFramework + pass print("Session finished") _LOGGER.debug("Session finished")