Added type hints

This commit is contained in:
Marek Wolan
2023-07-14 12:01:38 +01:00
parent a923d818d3
commit c57ed6edcd
16 changed files with 166 additions and 128 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Union
from typing import Any, Dict, Final, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
@@ -29,7 +29,7 @@ class PrimaiteSession:
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
):
) -> None:
"""
The PrimaiteSession constructor.
@@ -52,7 +52,7 @@ class PrimaiteSession:
self.learning_path: Path = None # noqa
self.evaluation_path: Path = None # noqa
def setup(self):
def setup(self) -> None:
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
@@ -123,8 +123,8 @@ class PrimaiteSession:
def learn(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Train the agent.
@@ -135,8 +135,8 @@ class PrimaiteSession:
def evaluate(
self,
**kwargs,
):
**kwargs: Any,
) -> None:
"""
Evaluate the agent.
@@ -145,6 +145,6 @@ class PrimaiteSession:
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(**kwargs)
def close(self):
def close(self) -> None:
"""Closes the agent."""
self._agent_session.close()