Files
PrimAITE/src/primaite/agents/hardcoded_abc.py

119 lines
3.6 KiB
Python
Raw Normal View History

# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import time
from abc import abstractmethod
from pathlib import Path
2023-07-18 10:13:54 +01:00
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
class HardCodedAgentSessionABC(AgentSessionABC):
    """
    An Agent Session ABC for evaluating deterministic agents.

    This class cannot be directly instantiated; it must be subclassed with all
    abstract methods implemented.
    """

    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = "",
        lay_down_config_path: Optional[Union[str, Path]] = "",
        session_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Initialise a hardcoded agent session.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[path, str]
        :param session_path: Forwarded unchanged to the parent session initialiser.
            NOTE(review): presumably the directory of an existing session - confirm
            against ``AgentSessionABC``.
        :type session_path: Optional[Union[path, str]]
        """
        super().__init__(training_config_path, lay_down_config_path, session_path)
        self._setup()

    def _setup(self) -> None:
        """Create the environment, run the parent setup, then mark the session as evaluate-only."""
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        # The environment is created before the parent setup runs; the flags
        # are set afterwards so the parent setup cannot overwrite them.
        super()._setup()
        self._can_learn = False
        self._can_evaluate = True

    def _save_checkpoint(self) -> None:
        """Do nothing; this session type does not write checkpoints."""
        pass

    def _get_latest_checkpoint(self) -> None:
        """Do nothing; this session type does not read checkpoints."""
        pass

    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        Only logs a warning; no training is performed.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def _calculate_action(self, obs: np.ndarray) -> Any:
        """
        Calculate the action to take for the given observation.

        The returned value is passed directly to the environment's ``step``
        call by :meth:`evaluate`.

        :param obs: The most recent observation returned by the environment.
        :return: The action to apply at the next step.
        """
        pass

    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        Runs ``num_eval_episodes`` episodes of at most ``num_eval_steps`` steps
        each, resetting the environment after every episode. The environment is
        always closed when evaluation finishes, including when a step raises.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        self._env.set_as_eval()  # noqa
        self.is_eval = True

        time_steps = self._training_config.num_eval_steps
        episodes = self._training_config.num_eval_episodes

        try:
            # Reset env and collect the initial observation.
            obs = self._env.reset()
            for _ in range(episodes):
                for _ in range(time_steps):
                    # Calculate action
                    action = self._calculate_action(obs)

                    # Perform the step
                    obs, _, done, _ = self._env.step(action)
                    if done:
                        break

                    # Introduce a delay between steps; the configured value is
                    # divided by 1000 because time.sleep takes seconds.
                    time.sleep(self._training_config.time_delay / 1000)

                # Reset at the end of every episode (including the last one).
                obs = self._env.reset()
        finally:
            # Close the environment even if an action or step raised.
            self._env.close()

    @classmethod
    def load(cls, path: Optional[Union[str, Path]] = None) -> None:
        """
        Load an agent from file.

        Only logs a warning; nothing is loaded and ``None`` is returned.

        :param path: Ignored.
        """
        _LOGGER.warning("Deterministic agents cannot be loaded")

    def save(self) -> None:
        """Save the agent. Only logs a warning; nothing is written."""
        _LOGGER.warning("Deterministic agents cannot be saved")

    def export(self) -> None:
        """Export the agent to transportable file format. Only logs a warning; nothing is written."""
        _LOGGER.warning("Deterministic agents cannot be exported")