Files
PrimAITE/src/primaite/primaite_session.py
Chris McCarthy 73015802ec #917 - Integrated the PrimaiteSession into all tests.
- Ran a full pre-commit hook and thus encountered tons of fixes required
2023-06-30 09:08:13 +01:00

218 lines
7.8 KiB
Python

from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Optional, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import (
DoNothingACLAgent,
DoNothingNodeAgent,
DummyAgent,
RandomAgent,
)
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
SessionType,
)
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
# Module-level logger, named after this module.
_LOGGER = getLogger(__name__)
class PrimaiteSession:
    """
    The PrimaiteSession class.

    Provides a single learning and evaluation entry point for all training
    and lay down configurations.

    Usage: construct with the two config paths, call :meth:`setup` to build
    the agent session, then :meth:`learn` and/or :meth:`evaluate`, and
    finally :meth:`close`. The learn/evaluate/close methods delegate to the
    agent session created by :meth:`setup`, so they must not be called
    before it.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
    ):
        """
        The PrimaiteSession constructor.

        Both config files are loaded immediately; the agent session itself
        is only created by :meth:`setup`.

        :param training_config_path: The training config path.
        :param lay_down_config_path: The lay down config path.
        """
        # Path() accepts both str and Path, so no isinstance check is needed.
        self._training_config_path: Final[Path] = Path(training_config_path)
        self._training_config: Final[TrainingConfig] = training_config.load(
            self._training_config_path
        )

        self._lay_down_config_path: Final[Path] = Path(lay_down_config_path)
        self._lay_down_config: Dict = lay_down_config.load(
            self._lay_down_config_path
        )

        # The attributes below are populated by setup().
        self._agent_session: Optional[AgentSessionABC] = None
        self.session_path: Optional[Path] = None
        self.timestamp_str: Optional[str] = None
        self.learning_path: Optional[Path] = None
        self.evaluation_path: Optional[Path] = None

    def setup(self):
        """
        Performs the session setup.

        Instantiates the agent session selected by the training config's
        ``agent_framework`` (and, for the custom framework, its
        ``agent_identifier`` and ``action_type``). It then copies the agent
        session's ``session_path``, ``timestamp_str``, ``learning_path``
        and ``evaluation_path`` onto this session.

        :raises NotImplementedError: If a hardcoded or do-nothing custom
            agent is configured with the ``ANY`` action type.
        :raises ValueError: If the agent framework, agent identifier and
            action type combination is not supported.
        """
        agent_framework = self._training_config.agent_framework
        if agent_framework == AgentFramework.CUSTOM:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
            )
            agent_class = self._custom_agent_class()
        elif agent_framework == AgentFramework.SB3:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
            )
            # Stable Baselines3 Agent
            agent_class = SB3Agent
        elif agent_framework == AgentFramework.RLLIB:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
            )
            # Ray RLlib Agent
            agent_class = RLlibAgent
        else:
            raise ValueError(f"Invalid agent framework: {agent_framework}")

        # Every supported agent session shares the same constructor signature.
        self._agent_session = agent_class(
            self._training_config_path, self._lay_down_config_path
        )

        self.session_path = self._agent_session.session_path
        self.timestamp_str = self._agent_session.timestamp_str
        self.learning_path = self._agent_session.learning_path
        self.evaluation_path = self._agent_session.evaluation_path

    def _custom_agent_class(self) -> type:
        """
        Select the agent session class for the ``CUSTOM`` agent framework.

        :return: The agent session class matching the configured agent
            identifier (and, for hardcoded/do-nothing agents, action type).
        :raises NotImplementedError: If a hardcoded or do-nothing agent is
            configured with the ``ANY`` action type.
        :raises ValueError: If the agent identifier or action type is not
            supported by the custom framework.
        """
        agent_identifier = self._training_config.agent_identifier
        if agent_identifier == AgentIdentifier.HARDCODED:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Identifier ="
                f" {AgentIdentifier.HARDCODED}"
            )
            # Deterministic hardcoded agents, one per action space.
            return self._agent_class_for_action_type(
                node_agent_class=HardCodedNodeAgent,
                acl_agent_class=HardCodedACLAgent,
            )
        if agent_identifier == AgentIdentifier.DO_NOTHING:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Identifier ="
                f" {AgentIdentifier.DO_NOTHING}"
            )
            # Do-nothing agents, one per action space.
            return self._agent_class_for_action_type(
                node_agent_class=DoNothingNodeAgent,
                acl_agent_class=DoNothingACLAgent,
            )
        if agent_identifier == AgentIdentifier.RANDOM:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Identifier ="
                f" {AgentIdentifier.RANDOM}"
            )
            return RandomAgent
        if agent_identifier == AgentIdentifier.DUMMY:
            _LOGGER.debug(
                f"PrimaiteSession Setup: Agent Identifier ="
                f" {AgentIdentifier.DUMMY}"
            )
            return DummyAgent
        # Invalid AgentFramework / AgentIdentifier combo.
        raise ValueError(
            f"Invalid agent identifier {agent_identifier} for agent"
            f" framework {AgentFramework.CUSTOM}"
        )

    def _agent_class_for_action_type(
        self, node_agent_class: type, acl_agent_class: type
    ) -> type:
        """
        Pick the node- or ACL-action-space variant of an agent class.

        :param node_agent_class: Class to use for ``ActionType.NODE``.
        :param acl_agent_class: Class to use for ``ActionType.ACL``.
        :return: The class matching the configured action type.
        :raises NotImplementedError: For ``ActionType.ANY``.
        :raises ValueError: For any other action type.
        """
        action_type = self._training_config.action_type
        if action_type == ActionType.NODE:
            return node_agent_class
        if action_type == ActionType.ACL:
            return acl_agent_class
        if action_type == ActionType.ANY:
            raise NotImplementedError(
                f"Action type {ActionType.ANY} is not supported for agent"
                f" identifier {self._training_config.agent_identifier}"
            )
        # Invalid AgentIdentifier / ActionType combo.
        raise ValueError(
            f"Invalid action type {action_type} for agent identifier"
            f" {self._training_config.agent_identifier}"
        )

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs,
    ):
        """
        Train the agent.

        This is a no-op when the configured session type is ``EVAL``.
        Requires :meth:`setup` to have been called.

        :param time_steps: The number of time steps per episode.
        :param episodes: The number of episodes.
        :param kwargs: Any agent-framework specific key word args.
        """
        if self._training_config.session_type != SessionType.EVAL:
            self._agent_session.learn(time_steps, episodes, **kwargs)

    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs,
    ):
        """
        Evaluate the agent.

        This is a no-op when the configured session type is ``TRAIN``.
        Requires :meth:`setup` to have been called.

        :param time_steps: The number of time steps per episode.
        :param episodes: The number of episodes.
        :param kwargs: Any agent-framework specific key word args.
        """
        if self._training_config.session_type != SessionType.TRAIN:
            self._agent_session.evaluate(time_steps, episodes, **kwargs)

    def close(self):
        """Closes the agent session created by :meth:`setup`."""
        self._agent_session.close()