#917 - started working on the Agent abstract classes and sub-classes
This commit is contained in:
@@ -20,13 +20,6 @@ The environment config file consists of the following attributes:
|
|||||||
|
|
||||||
**Generic Config Values**
|
**Generic Config Values**
|
||||||
|
|
||||||
* **agent_identifier** [enum]
|
|
||||||
|
|
||||||
This identifies the agent to use for the session. Select from one of the following:
|
|
||||||
|
|
||||||
* GENERIC - Where a user-developed agent is to be used
|
|
||||||
* STABLE_BASELINES3_PPO - Use an SB3 PPO agent
|
|
||||||
* STABLE_BASELINES3_A2C - Use an SB3 A2C agent
|
|
||||||
|
|
||||||
* **agent_framework** [enum]
|
* **agent_framework** [enum]
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional, Final, Dict, Any
|
||||||
|
|
||||||
|
from primaite import getLogger
|
||||||
|
from primaite.config.training_config import TrainingConfig
|
||||||
from primaite.environment.primaite_env import Primaite
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
|
||||||
|
_LOGGER = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentABC(ABC):
|
class AgentABC(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, env: Primaite):
|
def __init__(self, env: Primaite):
|
||||||
self._env: Primaite = env
|
self._env: Primaite = env
|
||||||
|
self._training_config: Final[TrainingConfig] = self._env.training_config
|
||||||
|
self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
|
||||||
self._agent = None
|
self._agent = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -33,4 +39,35 @@ class AgentABC(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def export(self):
|
def export(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeterministicAgentABC(AgentABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, env: Primaite):
|
||||||
|
self._env: Primaite = env
|
||||||
|
self._agent = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _setup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||||
|
pass
|
||||||
|
_LOGGER.warning("Deterministic agents cannot learn")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def export(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,28 +1,35 @@
|
|||||||
# from typing import Optional
|
from typing import Optional
|
||||||
#
|
|
||||||
# from primaite.agents.agent_abc import AgentABC
|
from stable_baselines3 import PPO
|
||||||
# from primaite.environment.primaite_env import Primaite
|
|
||||||
#
|
from primaite.agents.agent_abc import AgentABC
|
||||||
#
|
from primaite.environment.primaite_env import Primaite
|
||||||
# class SB3PPO(AgentABC):
|
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||||
# def __init__(self, env: Primaite):
|
|
||||||
# super().__init__(env)
|
class SB3PPO(AgentABC):
|
||||||
#
|
def __init__(self, env: Primaite):
|
||||||
# def _setup(self):
|
super().__init__(env)
|
||||||
# if self._env.training_config
|
|
||||||
# pass
|
def _setup(self):
|
||||||
#
|
self._agent = PPO(
|
||||||
# def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
PPOMlp,
|
||||||
# pass
|
self._env,
|
||||||
#
|
verbose=0,
|
||||||
# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
n_steps=self._training_config.num_steps
|
||||||
# pass
|
)
|
||||||
#
|
|
||||||
# def load(self):
|
|
||||||
# pass
|
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||||
#
|
pass
|
||||||
# def save(self):
|
|
||||||
# pass
|
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||||
#
|
pass
|
||||||
# def export(self):
|
|
||||||
# pass
|
def load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def export(self):
|
||||||
|
pass
|
||||||
@@ -7,6 +7,8 @@ from typing import Final, Optional, Union
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from primaite import getLogger, SESSIONS_DIR
|
from primaite import getLogger, SESSIONS_DIR
|
||||||
|
from primaite.common.enums import AgentFramework, RedAgentIdentifier, \
|
||||||
|
ActionType
|
||||||
from primaite.config.training_config import TrainingConfig
|
from primaite.config.training_config import TrainingConfig
|
||||||
from primaite.environment.primaite_env import Primaite
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
|
||||||
@@ -61,7 +63,7 @@ class PrimaiteSession:
|
|||||||
|
|
||||||
|
|
||||||
self._env = None
|
self._env = None
|
||||||
self._training_config = None
|
self._training_config: TrainingConfig
|
||||||
self._can_learn: bool = False
|
self._can_learn: bool = False
|
||||||
_LOGGER.debug("")
|
_LOGGER.debug("")
|
||||||
|
|
||||||
@@ -157,22 +159,70 @@ class PrimaiteSession:
|
|||||||
):
|
):
|
||||||
if self._can_learn:
|
if self._can_learn:
|
||||||
# Run environment against an agent
|
# Run environment against an agent
|
||||||
if self._training_config.agent_identifier == "GENERIC":
|
if self._training_config.agent_framework == AgentFramework.NONE:
|
||||||
run_generic(env=env, config_values=config_values)
|
if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM:
|
||||||
elif self._training_config == "STABLE_BASELINES3_PPO":
|
# Stochastic Random Agent
|
||||||
run_stable_baselines3_ppo(
|
run_generic(env=env, config_values=config_values)
|
||||||
env=env,
|
|
||||||
config_values=config_values,
|
elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED:
|
||||||
session_path=session_dir,
|
if self._training_config.action_type == ActionType.NODE:
|
||||||
timestamp_str=timestamp_str,
|
# Deterministic Hardcoded Agent with Node Action Space
|
||||||
)
|
pass
|
||||||
elif self._training_config == "STABLE_BASELINES3_A2C":
|
|
||||||
run_stable_baselines3_a2c(
|
elif self._training_config.action_type == ActionType.ACL:
|
||||||
env=env,
|
# Deterministic Hardcoded Agent with ACL Action Space
|
||||||
config_values=config_values,
|
pass
|
||||||
session_path=session_dir,
|
|
||||||
timestamp_str=timestamp_str,
|
elif self._training_config.action_type == ActionType.ANY:
|
||||||
)
|
# Deterministic Hardcoded Agent with ANY Action Space
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Invalid RedAgentIdentifier ActionType combo
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Invalid AgentFramework RedAgentIdentifier combo
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||||
|
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||||
|
# Stable Baselines3/Proximal Policy Optimization
|
||||||
|
run_stable_baselines3_ppo(
|
||||||
|
env=env,
|
||||||
|
config_values=config_values,
|
||||||
|
session_path=session_dir,
|
||||||
|
timestamp_str=timestamp_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||||
|
# Stable Baselines3/Advantage Actor Critic
|
||||||
|
run_stable_baselines3_a2c(
|
||||||
|
env=env,
|
||||||
|
config_values=config_values,
|
||||||
|
session_path=session_dir,
|
||||||
|
timestamp_str=timestamp_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Invalid AgentFramework RedAgentIdentifier combo
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||||
|
if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO:
|
||||||
|
# Ray RLlib/Proximal Policy Optimization
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C:
|
||||||
|
# Ray RLlib/Advantage Actor Critic
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Invalid AgentFramework RedAgentIdentifier combo
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Invalid AgentFramework
|
||||||
|
pass
|
||||||
|
|
||||||
print("Session finished")
|
print("Session finished")
|
||||||
_LOGGER.debug("Session finished")
|
_LOGGER.debug("Session finished")
|
||||||
|
|||||||
Reference in New Issue
Block a user