From 50c9ef16cbca4757a493e67bf4632fe1c984a55d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 24 Nov 2023 09:18:18 +0000 Subject: [PATCH] Move policy module into session --- src/primaite/game/policy/__init__.py | 4 ---- src/primaite/session/policy/__init__.py | 4 ++++ src/primaite/{game => session}/policy/policy.py | 0 src/primaite/{game => session}/policy/rllib.py | 2 +- src/primaite/{game => session}/policy/sb3.py | 2 +- src/primaite/session/session.py | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) delete mode 100644 src/primaite/game/policy/__init__.py create mode 100644 src/primaite/session/policy/__init__.py rename src/primaite/{game => session}/policy/policy.py (100%) rename src/primaite/{game => session}/policy/rllib.py (98%) rename src/primaite/{game => session}/policy/sb3.py (98%) diff --git a/src/primaite/game/policy/__init__.py b/src/primaite/game/policy/__init__.py deleted file mode 100644 index 9c0e4199..00000000 --- a/src/primaite/game/policy/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from primaite.game.policy.rllib import RaySingleAgentPolicy -from primaite.game.policy.sb3 import SB3Policy - -__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/session/policy/__init__.py b/src/primaite/session/policy/__init__.py new file mode 100644 index 00000000..811c7a54 --- /dev/null +++ b/src/primaite/session/policy/__init__.py @@ -0,0 +1,4 @@ +from primaite.session.policy.rllib import RaySingleAgentPolicy +from primaite.session.policy.sb3 import SB3Policy + +__all__ = ["SB3Policy", "RaySingleAgentPolicy"] diff --git a/src/primaite/game/policy/policy.py b/src/primaite/session/policy/policy.py similarity index 100% rename from src/primaite/game/policy/policy.py rename to src/primaite/session/policy/policy.py diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/session/policy/rllib.py similarity index 98% rename from src/primaite/game/policy/rllib.py rename to src/primaite/session/policy/rllib.py index fcebf40d..7ba3edd0 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/session/policy/rllib.py @@ -1,8 +1,8 @@ from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING -from primaite.game.policy.policy import PolicyABC from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.policy.policy import PolicyABC if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions diff --git a/src/primaite/game/policy/sb3.py b/src/primaite/session/policy/sb3.py similarity index 98% rename from src/primaite/game/policy/sb3.py rename to src/primaite/session/policy/sb3.py index 64eebfc7..051e2770 100644 --- a/src/primaite/game/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -8,7 +8,7 @@ from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy as PPO_MLP -from primaite.game.policy.policy import PolicyABC +from primaite.session.policy.policy import PolicyABC if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 9f567a95..80b63ba7 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -5,12 +5,12 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict from primaite.game.game import PrimaiteGame - -# from primaite.game.game import PrimaiteGame -from primaite.game.policy.policy import PolicyABC from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings +# from primaite.game.game import PrimaiteGame +from primaite.session.policy.policy import PolicyABC + class TrainingOptions(BaseModel): """Options for training the RL agent."""