#917 - started working on the Agent abstract classes and sub-classes
This commit is contained in:
@@ -20,13 +20,6 @@ The environment config file consists of the following attributes:
|
||||
|
||||
**Generic Config Values**
|
||||
|
||||
* **agent_identifier** [enum]
|
||||
|
||||
This identifies the agent to use for the session. Select from one of the following:
|
||||
|
||||
* GENERIC - Where a user-developed agent is to be used
|
||||
* STABLE_BASELINES3_PPO - Use an SB3 PPO agent
|
||||
* STABLE_BASELINES3_A2C - Use an SB3 A2C agent
|
||||
|
||||
* **agent_framework** [enum]
|
||||
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Final, Dict, Any
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class AgentABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: Primaite):
|
||||
self._env: Primaite = env
|
||||
self._training_config: Final[TrainingConfig] = self._env.training_config
|
||||
self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
|
||||
self._agent = None
|
||||
|
||||
@abstractmethod
|
||||
@@ -33,4 +39,35 @@ class AgentABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def export(self):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class DeterministicAgentABC(AgentABC):
    """Abstract base class for deterministic (non-learning) agents.

    Concrete sub-classes must implement ``__init__``, ``_setup``,
    ``evaluate``, ``load``, ``save`` and ``export``. ``learn`` is implemented
    here and only logs a warning.
    """

    @abstractmethod
    def __init__(self, env: Primaite):
        """Initialise the agent state from the environment.

        :param env: The Primaite environment instance the agent is bound to.
        """
        # Delegate to AgentABC.__init__ rather than re-assigning a subset of
        # its attributes by hand: the base constructor sets _env and _agent
        # and additionally populates _training_config and _lay_down_config
        # from the environment, which were previously left unset here.
        super().__init__(env)

    @abstractmethod
    def _setup(self):
        """Build the underlying agent; must be provided by sub-classes."""

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Log a warning instead of training.

        :param time_steps: Unused.
        :param episodes: Unused.
        """
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the agent; must be provided by sub-classes.

        :param time_steps: Optional number of time steps.
        :param episodes: Optional number of episodes.
        """

    @abstractmethod
    def load(self):
        """Load the agent; must be provided by sub-classes."""

    @abstractmethod
    def save(self):
        """Save the agent; must be provided by sub-classes."""

    @abstractmethod
    def export(self):
        """Export the agent; must be provided by sub-classes."""
|
||||
|
||||
@@ -1,28 +1,35 @@
|
||||
# from typing import Optional
|
||||
#
|
||||
# from primaite.agents.agent_abc import AgentABC
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
#
|
||||
#
|
||||
# class SB3PPO(AgentABC):
|
||||
# def __init__(self, env: Primaite):
|
||||
# super().__init__(env)
|
||||
#
|
||||
# def _setup(self):
|
||||
# if self._env.training_config
|
||||
# pass
|
||||
#
|
||||
# def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
# pass
|
||||
#
|
||||
# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
# pass
|
||||
#
|
||||
# def load(self):
|
||||
# pass
|
||||
#
|
||||
# def save(self):
|
||||
# pass
|
||||
#
|
||||
# def export(self):
|
||||
# pass
|
||||
from typing import Optional
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from primaite.agents.agent_abc import AgentABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
class SB3PPO(AgentABC):
    """Agent built on the Stable Baselines3 ``PPO`` class."""

    def __init__(self, env: Primaite):
        """Initialise the agent through the ``AgentABC`` constructor.

        :param env: The Primaite environment instance handed to ``AgentABC``.
        """
        super().__init__(env)

    def _setup(self):
        """Create the PPO model and store it on ``self._agent``."""
        # Rollout length comes from the training config held by the agent.
        n_steps = self._training_config.num_steps
        self._agent = PPO(PPOMlp, self._env, verbose=0, n_steps=n_steps)

    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Placeholder: currently does nothing.

        :param time_steps: Unused.
        :param episodes: Unused.
        """

    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Placeholder: currently does nothing.

        :param time_steps: Unused.
        :param episodes: Unused.
        """

    def load(self):
        """Placeholder: currently does nothing."""

    def save(self):
        """Placeholder: currently does nothing."""

    def export(self):
        """Placeholder: currently does nothing."""
|
||||
@@ -7,6 +7,8 @@ from typing import Final, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
|
||||
ActionType
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -61,7 +63,7 @@ class PrimaiteSession:
|
||||
|
||||
|
||||
self._env = None
|
||||
self._training_config = None
|
||||
self._training_config: TrainingConfig
|
||||
self._can_learn: bool = False
|
||||
_LOGGER.debug("")
|
||||
|
||||
@@ -157,22 +159,70 @@ class PrimaiteSession:
|
||||
):
|
||||
if self._can_learn:
|
||||
# Run environment against an agent
|
||||
if self._training_config.agent_identifier == "GENERIC":
|
||||
run_generic(env=env, config_values=config_values)
|
||||
elif self._training_config == "STABLE_BASELINES3_PPO":
|
||||
run_stable_baselines3_ppo(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
elif self._training_config == "STABLE_BASELINES3_A2C":
|
||||
run_stable_baselines3_a2c(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
if self._training_config.agent_framework == AgentFramework.NONE:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
|
||||
# Stochastic Random Agent
|
||||
run_generic(env=env, config_values=config_values)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
pass
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
pass
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid RedAgentIdentifier ActionType combo
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Stable Baselines3/Proximal Policy Optimization
|
||||
run_stable_baselines3_ppo(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Stable Baselines3/Advantage Actor Critic
|
||||
run_stable_baselines3_a2c(
|
||||
env=env,
|
||||
config_values=config_values,
|
||||
session_path=session_dir,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||
# Ray RLlib/Proximal Policy Optimization
|
||||
pass
|
||||
|
||||
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||
# Ray RLlib/Advantage Actor Critic
|
||||
pass
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework RedAgentIdentifier combo
|
||||
pass
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
pass
|
||||
|
||||
print("Session finished")
|
||||
_LOGGER.debug("Session finished")
|
||||
|
||||
Reference in New Issue
Block a user