#917 - started working on the Agent abstract classes and sub-classes

This commit is contained in:
Chris McCarthy
2023-06-15 09:48:44 +01:00
parent eb3368edd6
commit 6849939265
4 changed files with 141 additions and 54 deletions

View File

@@ -20,13 +20,6 @@ The environment config file consists of the following attributes:
**Generic Config Values**
* **agent_identifier** [enum]
This identifies the agent to use for the session. Select from one of the following:
* GENERIC - Where a user developed agent is to be used
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
* STABLE_BASELINES3_A2C - use a SB3 A2C agent
* **agent_framework** [enum]

View File

@@ -1,14 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional, Final, Dict, Any
from primaite import getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class AgentABC(ABC): class AgentABC(ABC):
@abstractmethod @abstractmethod
def __init__(self, env: Primaite): def __init__(self, env: Primaite):
self._env: Primaite = env self._env: Primaite = env
self._training_config: Final[TrainingConfig] = self._env.training_config
self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
self._agent = None self._agent = None
@abstractmethod @abstractmethod
@@ -33,4 +39,35 @@ class AgentABC(ABC):
@abstractmethod @abstractmethod
def export(self): def export(self):
pass pass
class DeterministicAgentABC(AgentABC):
    """Abstract base class for deterministic (rule-based) agents.

    Deterministic agents do not train: ``learn`` is a concrete no-op that
    only emits a warning, while the remaining lifecycle hooks stay abstract
    for concrete subclasses to implement.
    """

    @abstractmethod
    def __init__(self, env: Primaite):
        """Initialise shared agent state from the PrimAITE environment.

        Delegates to ``AgentABC.__init__`` so that ``_env``,
        ``_training_config``, ``_lay_down_config`` and ``_agent`` are set up
        consistently with the learning agents, instead of duplicating (and
        partially dropping) that setup here.
        """
        super().__init__(env)

    @abstractmethod
    def _setup(self):
        """Perform agent-specific setup; implemented by subclasses."""
        pass

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Deterministic agents cannot learn; warn and return immediately."""
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def load(self):
        """Load a previously saved agent; implemented by subclasses."""
        pass

    @abstractmethod
    def save(self):
        """Persist the agent; implemented by subclasses."""
        pass

    @abstractmethod
    def export(self):
        """Export the agent for external use; implemented by subclasses."""
        pass

View File

@@ -1,28 +1,35 @@
# from typing import Optional from typing import Optional
#
# from primaite.agents.agent_abc import AgentABC from stable_baselines3 import PPO
# from primaite.environment.primaite_env import Primaite
# from primaite.agents.agent_abc import AgentABC
# from primaite.environment.primaite_env import Primaite
# class SB3PPO(AgentABC): from stable_baselines3.ppo import MlpPolicy as PPOMlp
# def __init__(self, env: Primaite):
class SB3PPO(AgentABC):
    """Proximal Policy Optimisation agent backed by Stable Baselines3."""

    def __init__(self, env: Primaite):
        """Initialise shared agent state from the PrimAITE environment."""
        super().__init__(env)

    def _setup(self):
        """Instantiate the SB3 PPO model against the environment.

        ``n_steps`` comes from the session training config so the rollout
        length matches the configured episode structure.
        """
        self._agent = PPO(
            PPOMlp,
            self._env,
            verbose=0,
            n_steps=self._training_config.num_steps,
        )

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Train the agent (not yet implemented)."""
        pass

    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the agent (not yet implemented)."""
        pass

    def load(self):
        """Load a saved agent (not yet implemented)."""
        pass

    def save(self):
        """Save the agent (not yet implemented)."""
        pass

    def export(self):
        """Export the agent (not yet implemented)."""
        pass

View File

@@ -7,6 +7,8 @@ from typing import Final, Optional, Union
from uuid import uuid4 from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR from primaite import getLogger, SESSIONS_DIR
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
ActionType
from primaite.config.training_config import TrainingConfig from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
@@ -61,7 +63,7 @@ class PrimaiteSession:
self._env = None self._env = None
self._training_config = None self._training_config: TrainingConfig
self._can_learn: bool = False self._can_learn: bool = False
_LOGGER.debug("") _LOGGER.debug("")
@@ -157,22 +159,70 @@ class PrimaiteSession:
): ):
if self._can_learn: if self._can_learn:
# Run environment against an agent # Run environment against an agent
if self._training_config.agent_identifier == "GENERIC": if self._training_config.agent_framework == AgentFramework.NONE:
run_generic(env=env, config_values=config_values) if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
elif self._training_config == "STABLE_BASELINES3_PPO": # Stochastic Random Agent
run_stable_baselines3_ppo( run_generic(env=env, config_values=config_values)
env=env,
config_values=config_values, elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
session_path=session_dir, if self._training_config.action_type == ActionType.NODE:
timestamp_str=timestamp_str, # Deterministic Hardcoded Agent with Node Action Space
) pass
elif self._training_config == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c( elif self._training_config.action_type == ActionType.ACL:
env=env, # Deterministic Hardcoded Agent with ACL Action Space
config_values=config_values, pass
session_path=session_dir,
timestamp_str=timestamp_str, elif self._training_config.action_type == ActionType.ANY:
) # Deterministic Hardcoded Agent with ANY Action Space
pass
else:
# Invalid RedAgentIdentifier ActionType combo
pass
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.SB3:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Stable Baselines3/Proximal Policy Optimization
run_stable_baselines3_ppo(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Stable Baselines3/Advantage Actor Critic
run_stable_baselines3_a2c(
env=env,
config_values=config_values,
session_path=session_dir,
timestamp_str=timestamp_str,
)
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
elif self._training_config.agent_framework == AgentFramework.RLLIB:
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
# Ray RLlib/Proximal Policy Optimization
pass
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
# Ray RLlib/Advantage Actor Critic
pass
else:
# Invalid AgentFramework RedAgentIdentifier combo
pass
else:
# Invalid AgentFramework
pass
print("Session finished") print("Session finished")
_LOGGER.debug("Session finished") _LOGGER.debug("Session finished")