#1595 missed hardcoded_abc file in commit
This commit is contained in:
108
src/primaite/agents/hardcoded_abc.py
Normal file
108
src/primaite/agents/hardcoded_abc.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
Reference in New Issue
Block a user