#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
@@ -9,18 +9,8 @@ from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import (
|
||||
DoNothingACLAgent,
|
||||
DoNothingNodeAgent,
|
||||
DummyAgent,
|
||||
RandomAgent,
|
||||
)
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
SessionType,
|
||||
)
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
@@ -49,16 +39,12 @@ class PrimaiteSession:
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path
|
||||
)
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(
|
||||
self._lay_down_config_path
|
||||
)
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
@@ -69,28 +55,16 @@ class PrimaiteSession:
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
|
||||
)
|
||||
if (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.HARDCODED
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.HARDCODED}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -100,24 +74,14 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.DO_NOTHING
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DO_NOTHINGD}"
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -127,49 +91,26 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.RANDOM
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.RANDOM}"
|
||||
)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
elif (
|
||||
self._training_config.agent_identifier == AgentIdentifier.DUMMY
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DUMMY}"
|
||||
)
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
@@ -182,35 +123,27 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
self._agent_session.learn(**kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Closes the agent."""
|
||||
|
||||
Reference in New Issue
Block a user