diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 49d56e67..da1331b0 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,7 +26,7 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward): class RewardFunction: """Manages the reward function for the agent.""" - __rew_class_identifiers: Dict[str, type[AbstractReward]] = { + __rew_class_identifiers: Dict[str, Type[AbstractReward]] = { "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, diff --git a/src/primaite/game/policy/policy.py b/src/primaite/game/policy/policy.py index a7052367..249c3b52 100644 --- a/src/primaite/game/policy/policy.py +++ b/src/primaite/game/policy/policy.py @@ -1,7 +1,7 @@ """Base class and common logic for RL policies.""" from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict, Type, TYPE_CHECKING if TYPE_CHECKING: from primaite.game.session import PrimaiteSession, TrainingOptions @@ -10,7 +10,7 @@ if TYPE_CHECKING: class PolicyABC(ABC): """Base class for reinforcement learning agents.""" - _registry: Dict[str, type["PolicyABC"]] = {} + _registry: Dict[str, Type["PolicyABC"]] = {} """ Registry of policy types, keyed by name. diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/game/policy/sb3.py index 10f22e05..bb35775a 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/game/policy/sb3.py @@ -1,6 +1,6 @@ """Stable baselines 3 policy.""" from pathlib import Path -from typing import Literal, Optional, TYPE_CHECKING, Union +from typing import Literal, Optional, Type, TYPE_CHECKING, Union from stable_baselines3 import A2C, PPO from stable_baselines3.a2c import MlpPolicy as A2C_MLP @@ -21,7 +21,7 @@ class SB3Policy(PolicyABC, identifier="SB3"): """Initialize a stable baselines 3 policy.""" super().__init__(session=session) - self._agent_class: type[Union[PPO, A2C]] + self._agent_class: Type[Union[PPO, A2C]] if algorithm == "PPO": self._agent_class = PPO policy = PPO_MLP diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index a2e83cbb..88c1e061 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -52,7 +52,7 @@ class PrimaiteGymEnv(gymnasium.Env): self.session: "PrimaiteSession" = session self.agent: ProxyAgent = agents[0] - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) @@ -70,7 +70,7 @@ class PrimaiteGymEnv(gymnasium.Env): return next_obs, reward, terminated, truncated, info - def reset(self, seed: Optional[int] = None) -> tuple[ObsType, dict[str, Any]]: + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" self.session.reset() state = self.session.get_sim_state() diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index c6179e9a..5e1da4ff 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -40,8 +40,6 @@ class TestPrimaiteSession: with temp_primaite_session as session: session: TempPrimaiteSession session.start_session() - for i in range(100): - print(session.io_manager.generate_session_path()) # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)