Fix typehints

This commit is contained in:
Marek Wolan
2023-11-16 15:28:38 +00:00
parent 13c49bf3ea
commit 0b9bdedebd
5 changed files with 8 additions and 10 deletions

View File

@@ -26,7 +26,7 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
class RewardFunction:
"""Manages the reward function for the agent."""
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,

View File

@@ -1,7 +1,7 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, TYPE_CHECKING
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, type["PolicyABC"]] = {}
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.

View File

@@ -1,6 +1,6 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING, Union
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
@@ -21,7 +21,7 @@ class SB3Policy(PolicyABC, identifier="SB3"):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: type[Union[PPO, A2C]]
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP

View File

@@ -52,7 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env):
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
@@ -70,7 +70,7 @@ class PrimaiteGymEnv(gymnasium.Env):
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]:
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()

View File

@@ -40,8 +40,6 @@ class TestPrimaiteSession:
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
for i in range(100):
print(session.io_manager.generate_session_path())
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)