Ensure everything is still typehinted

This commit is contained in:
Marek Wolan
2023-07-18 10:13:54 +01:00
parent a7a5fb8598
commit 393505b98b
3 changed files with 18 additions and 16 deletions

View File

@@ -254,7 +254,7 @@ class AgentSessionABC(ABC):
def _get_latest_checkpoint(self) -> None:
pass
def load(self, path: Union[str, Path]):
def load(self, path: Union[str, Path]) -> None:
"""Load an agent from file."""
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)

View File

@@ -2,7 +2,9 @@
import time
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
import numpy as np
from primaite import getLogger
from primaite.agents.agent_abc import AgentSessionABC
@@ -24,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
training_config_path: Optional[Union[str, Path]] = "",
lay_down_config_path: Optional[Union[str, Path]] = "",
session_path: Optional[Union[str, Path]] = None,
):
) -> None:
"""
Initialise a hardcoded agent session.
@@ -37,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
super().__init__(training_config_path, lay_down_config_path, session_path)
self._setup()
def _setup(self):
def _setup(self) -> None:
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
@@ -48,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._can_learn = False
self._can_evaluate = True
def _save_checkpoint(self):
def _save_checkpoint(self) -> None:
pass
def _get_latest_checkpoint(self):
def _get_latest_checkpoint(self) -> None:
pass
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -66,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
_LOGGER.warning("Deterministic agents cannot learn")
@abstractmethod
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> None:
pass
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -103,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env.close()
@classmethod
def load(cls, path=None):
def load(cls, path: Union[str, Path] = None) -> None:
"""Load an agent from file."""
_LOGGER.warning("Deterministic agents cannot be loaded")
def save(self):
def save(self) -> None:
"""Save the agent."""
_LOGGER.warning("Deterministic agents cannot be saved")
def export(self):
def export(self) -> None:
"""Export the agent to transportable file format."""
_LOGGER.warning("Deterministic agents cannot be exported")

View File

@@ -1,7 +1,7 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import json
from pathlib import Path
from typing import Union
from typing import Any, Dict, Union
import yaml
@@ -10,7 +10,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
"""
Loads a session metadata from the given directory path.