Add typehints
This commit is contained in:
@@ -6,7 +6,7 @@ from bisect import bisect
|
||||
from logging import Formatter, Logger, LogRecord, StreamHandler
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final
|
||||
from typing import Any, Dict, Final
|
||||
|
||||
import pkg_resources
|
||||
import yaml
|
||||
@@ -16,7 +16,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
"""An instance of `PlatformDirs` set with appname='primaite'."""
|
||||
|
||||
|
||||
def _get_primaite_config():
|
||||
def _get_primaite_config() -> Dict:
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
@@ -72,7 +72,7 @@ class _LevelFormatter(Formatter):
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs):
|
||||
def __init__(self, formats: Dict[int, str], **kwargs: Any) -> str:
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
|
||||
@@ -141,7 +141,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
save_checkpoint = False
|
||||
|
||||
@@ -7,7 +7,7 @@ from primaite.common.enums import SoftwareState
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState):
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
|
||||
@@ -216,7 +216,7 @@ class TrainingConfig:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class Node:
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a node.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class NodeStateInstructionGreen(object):
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user