# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
|
import time
|
|
from abc import abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
from primaite import getLogger
|
|
from primaite.agents.agent_abc import AgentSessionABC
|
|
from primaite.environment.primaite_env import Primaite
|
|
|
|
_LOGGER = getLogger(__name__)
|
|
|
|
|
|
class HardCodedAgentSessionABC(AgentSessionABC):
    """
    An Agent Session ABC for evaluating deterministic agents.

    This class cannot be directly instantiated and must be inherited from with all abstract methods
    implemented.
    """

    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = "",
        lay_down_config_path: Optional[Union[str, Path]] = "",
        session_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Initialise a hardcoded agent session.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[path, str]
        :param session_path: Directory for session output; passed through to the parent session
            (presumably a default is derived when None — confirm against `AgentSessionABC`).
        :type session_path: Optional[Union[path, str]]
        """
        super().__init__(training_config_path, lay_down_config_path, session_path)
        self._setup()

    def _setup(self) -> None:
        """Create the PrimAITE environment and mark this session as evaluate-only."""
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        super()._setup()
        # Deterministic agents have fixed behaviour: they can only be evaluated.
        self._can_learn = False
        self._can_evaluate = True

    def _save_checkpoint(self) -> None:
        """Do nothing: deterministic agents have no learned state to checkpoint."""
        pass

    def _get_latest_checkpoint(self) -> None:
        """Do nothing: deterministic agents have no checkpoints to restore."""
        pass

    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        Deterministic agents cannot learn, so this only logs a warning.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def _calculate_action(self, obs: np.ndarray) -> Any:
        """
        Calculate the action to take for a given observation.

        :param obs: The observation returned by the environment.
        :return: The action to pass to the environment's ``step``.
        """
        pass

    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        Runs ``num_eval_episodes`` episodes of at most ``num_eval_steps`` steps each,
        then closes the environment.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        self._env.set_as_eval()  # noqa
        self.is_eval = True

        time_steps = self._training_config.num_eval_steps
        episodes = self._training_config.num_eval_episodes

        # Collect the initial observation before the first episode.
        obs = self._env.reset()
        for _episode in range(episodes):
            for _step in range(time_steps):
                # Calculate action
                action = self._calculate_action(obs)

                # Perform the step
                obs, reward, done, info = self._env.step(action)

                if done:
                    break

                # Introduce a delay between steps (time_delay is in milliseconds)
                time.sleep(self._training_config.time_delay / 1000)
            # Reset the env ready for the next episode.
            obs = self._env.reset()
        self._env.close()

    @classmethod
    def load(cls, path: Optional[Union[str, Path]] = None) -> None:
        """Load an agent from file."""
        _LOGGER.warning("Deterministic agents cannot be loaded")

    def save(self) -> None:
        """Save the agent."""
        _LOGGER.warning("Deterministic agents cannot be saved")

    def export(self) -> None:
        """Export the agent to transportable file format."""
        _LOGGER.warning("Deterministic agents cannot be exported")
|