Ensure everything is still typehinted
This commit is contained in:
@@ -254,7 +254,7 @@ class AgentSessionABC(ABC):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def load(self, path: Union[str, Path]):
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load an agent from file."""
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
@@ -24,7 +26,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
@@ -37,7 +39,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -48,16 +50,16 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -66,13 +68,13 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -103,14 +105,14 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
def load(cls, path: Union[str, Path] = None) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -10,7 +10,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user