Added type hints
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
@@ -29,7 +29,7 @@ class PrimaiteSession:
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
@@ -52,7 +52,7 @@ class PrimaiteSession:
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
def setup(self):
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
@@ -123,8 +123,8 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
@@ -135,8 +135,8 @@ class PrimaiteSession:
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
@@ -145,6 +145,6 @@ class PrimaiteSession:
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
Reference in New Issue
Block a user