Fix typehints
This commit is contained in:
@@ -26,7 +26,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
|
||||
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, type["PolicyABC"]] = {}
|
||||
_registry: Dict[str, Type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING, Union
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
@@ -21,7 +21,7 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self._agent_class: type[Union[PPO, A2C]]
|
||||
self._agent_class: Type[Union[PPO, A2C]]
|
||||
if algorithm == "PPO":
|
||||
self._agent_class = PPO
|
||||
policy = PPO_MLP
|
||||
|
||||
@@ -52,7 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen by the RL policy
|
||||
self.agent.store_action(action)
|
||||
@@ -70,7 +70,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]:
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
|
||||
@@ -40,8 +40,6 @@ class TestPrimaiteSession:
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
for i in range(100):
|
||||
print(session.io_manager.generate_session_path())
|
||||
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
|
||||
|
||||
Reference in New Issue
Block a user