diff --git a/pyproject.toml b/pyproject.toml
index 1e074c25..2f8cb803 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
     "tensorflow==2.12.0",
     "typer[all]==0.9.0",
     "pydantic==2.1.1",
-    "enlighten==1.12.2"
+    "enlighten==1.12.2",
+    "ray[rllib]==2.8.0"
 ]
 
 [tool.setuptools.dynamic]
diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py
new file mode 100644
index 00000000..721a7500
--- /dev/null
+++ b/src/primaite/game/policy/rllib.py
@@ -0,0 +1,25 @@
+
+
+from typing import Literal, Optional, Type, TYPE_CHECKING, Union
+
+from primaite.game.policy import PolicyABC
+
+if TYPE_CHECKING:
+    from primaite.game.session import PrimaiteSession, TrainingOptions
+
+from ray import rllib  # TODO: narrow to the specific RLlib names once they are used
+
+
+class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
+    """Single agent RL policy using Ray RLLib."""
+
+    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
+        """
+        Initialise the policy, passing the session to the parent initialiser.
+
+        :param session: Session handed to the base-class initialiser.
+        :param algorithm: Name of the RL algorithm, either "PPO" or "A2C". Accepted but not yet used here.
+        :param seed: Optional random seed. Accepted but not yet used here.
+        """
+        super().__init__(session=session)
+