Update session after it was split from game
This commit is contained in:
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame, TrainingOptions
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
@@ -32,7 +32,7 @@ class PolicyABC(ABC):
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteGame") -> None:
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
@@ -41,7 +41,7 @@ class PolicyABC(ABC):
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteGame" = session
|
||||
self.session: "PrimaiteSession" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -69,7 +69,7 @@ class PolicyABC(ABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
@@ -80,5 +80,3 @@ class PolicyABC(ABC):
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import TrainingOptions
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray.rllib.algorithms import ppo
|
||||
@@ -17,64 +14,33 @@ from ray.rllib.algorithms import ppo
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
ray.init()
|
||||
|
||||
class RayPrimaiteGym(gymnasium.Env):
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
self.action_space = session.env.action_space
|
||||
self.observation_space = session.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
obs, info = session.env.reset()
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
obs, reward, terminated, truncated, info = session.env.step(action)
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
config = {
|
||||
"env": RayPrimaiteGym,
|
||||
"env_config": {},
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": {"game": session.game},
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
self._algo = ppo.PPO(config=config)
|
||||
|
||||
# self._agent_config = (PPOConfig()
|
||||
# .update_from_dict({
|
||||
# "num_gpus":0,
|
||||
# "num_workers":0,
|
||||
# "batch_mode":"complete_episodes",
|
||||
# "framework":"torch",
|
||||
# })
|
||||
# .environment(
|
||||
# env="primaite",
|
||||
# env_config={"session": session, "agents": session.rl_agents,},
|
||||
# # disable_env_checking=True
|
||||
# )
|
||||
# # .rollouts(num_rollout_workers=0,
|
||||
# # num_envs_per_worker=0)
|
||||
# # .framework("tf2")
|
||||
# .evaluation(evaluation_num_workers=0)
|
||||
# )
|
||||
|
||||
# self._agent:Algorithm = self._agent_config.build(use_copy=False)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
for ep in range(n_episodes):
|
||||
res = self._algo.train()
|
||||
print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")
|
||||
self._algo.train()
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
raise NotImplementedError
|
||||
for ep in range(n_episodes):
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(self.session.game.options.max_episode_length):
|
||||
action = self._algo.compute_single_action(observation=obs, explore=False)
|
||||
obs, rew, term, trunc, info = self.session.env.step(action)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the policy to a file."""
|
||||
@@ -85,6 +51,6 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame, TrainingOptions
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
@@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy":
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
@@ -5,8 +5,8 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import load
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.config.load import example_config_path, load
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
@@ -32,7 +32,7 @@ def run(
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
@@ -42,6 +42,6 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.config:
|
||||
_LOGGER.error("Please provide a config file using the --config " "argument")
|
||||
args.config = example_config_path()
|
||||
|
||||
run(session_path=args.config)
|
||||
run(args.config)
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.game.environment import PrimaiteRayEnv"
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -61,7 +61,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.environment import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"env_config = {\"game\":game}\n",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"\n",
|
||||
"from primaite.game.environment import PrimaiteRayEnv"
|
||||
"from primaite.session.environment import PrimaiteRayEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -102,7 +102,11 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from primaite.config.load import example_config_path\n",
|
||||
"from primaite.main import run\n",
|
||||
"run(example_config_path())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -2,18 +2,18 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"from primaite.game.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -22,24 +22,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"installing DNSServer on node domain_controller\n",
|
||||
"installing DatabaseClient on node web_server\n",
|
||||
"installing WebServer on node web_server\n",
|
||||
"installing DatabaseService on node database_server\n",
|
||||
"service type not found DatabaseBackup\n",
|
||||
"installing DataManipulationBot on node client_1\n",
|
||||
"installing DNSClient on node client_1\n",
|
||||
"installing DNSClient on node client_2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(example_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
@@ -49,16 +34,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
|
||||
"gym = PrimaiteGymEnv(game=game)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -67,7 +52,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -76,27 +61,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=1000)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Literal, Optional
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.game.environment import PrimaiteGymEnv
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
|
||||
|
||||
@@ -15,7 +17,7 @@ class TrainingOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent"]
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
@@ -38,7 +40,7 @@ class SessionMode(Enum):
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
@@ -46,8 +48,8 @@ class PrimaiteSession:
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
|
||||
"""The environment that the RL algorithm can consume."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
@@ -55,6 +57,9 @@ class PrimaiteSession:
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
@@ -83,10 +88,26 @@ class PrimaiteSession:
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
sess = cls()
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
sess = cls(game=game)
|
||||
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config={"game": game})
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game=game)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
|
||||
Reference in New Issue
Block a user