diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index 9b0dd031..af860996 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -254,7 +254,7 @@ class AgentSessionABC(ABC): def _get_latest_checkpoint(self) -> None: pass - def load(self, path: Union[str, Path]): + def load(self, path: Union[str, Path]) -> None: """Load an agent from file.""" md_dict, training_config_path, laydown_config_path = parse_session_metadata(path) diff --git a/src/primaite/agents/hardcoded_abc.py b/src/primaite/agents/hardcoded_abc.py index ec4b53e7..0336f00e 100644 --- a/src/primaite/agents/hardcoded_abc.py +++ b/src/primaite/agents/hardcoded_abc.py @@ -2,7 +2,9 @@ import time from abc import abstractmethod from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union + +import numpy as np from primaite import getLogger from primaite.agents.agent_abc import AgentSessionABC @@ -24,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): training_config_path: Optional[Union[str, Path]] = "", lay_down_config_path: Optional[Union[str, Path]] = "", session_path: Optional[Union[str, Path]] = None, - ): + ) -> None: """ Initialise a hardcoded agent session. @@ -37,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): super().__init__(training_config_path, lay_down_config_path, session_path) self._setup() - def _setup(self): + def _setup(self) -> None: self._env: Primaite = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -48,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC): self._can_learn = False self._can_evaluate = True - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: pass - def _get_latest_checkpoint(self): + def _get_latest_checkpoint(self) -> None: pass def learn( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Train the agent. @@ -66,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC): _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod - def _calculate_action(self, obs): + def _calculate_action(self, obs: np.ndarray) -> None: pass def evaluate( self, - **kwargs, - ): + **kwargs: Any, + ) -> None: """ Evaluate the agent. @@ -103,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC): self._env.close() @classmethod - def load(cls, path=None): + def load(cls, path: Union[str, Path] = None) -> None: """Load an agent from file.""" _LOGGER.warning("Deterministic agents cannot be loaded") - def save(self): + def save(self) -> None: """Save the agent.""" _LOGGER.warning("Deterministic agents cannot be saved") - def export(self): + def export(self) -> None: """Export the agent to transportable file format.""" _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/utils/session_metadata_parser.py b/src/primaite/utils/session_metadata_parser.py index eb3c3339..0b0eaaec 100644 --- a/src/primaite/utils/session_metadata_parser.py +++ b/src/primaite/utils/session_metadata_parser.py @@ -1,7 +1,7 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import json from pathlib import Path -from typing import Union +from typing import Any, Dict, Union import yaml @@ -10,7 +10,7 @@ from primaite import getLogger _LOGGER = getLogger(__name__) -def parse_session_metadata(session_path: Union[Path, str], dict_only=False): +def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]: """ Loads a session metadata from the given directory path.