Update session after it was split from the game

This commit is contained in:
Marek Wolan
2023-11-23 00:54:19 +00:00
parent 1fd5298fc5
commit 14ae8be5e2
9 changed files with 76 additions and 113 deletions

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
@@ -32,7 +32,7 @@ class PolicyABC(ABC):
return
@abstractmethod
def __init__(self, session: "PrimaiteGame") -> None:
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
@@ -41,7 +41,7 @@ class PolicyABC(ABC):
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteGame" = session
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
@@ -69,7 +69,7 @@ class PolicyABC(ABC):
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "PolicyABC":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
@@ -80,5 +80,3 @@ class PolicyABC(ABC):
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass

View File

@@ -1,14 +1,11 @@
from pathlib import Path
from typing import Dict, Literal, Optional, SupportsFloat, Tuple, TYPE_CHECKING
import gymnasium
from gymnasium.core import ActType, ObsType
from typing import Literal, Optional, TYPE_CHECKING
from primaite.game.policy.policy import PolicyABC
from primaite.session.environment import PrimaiteRayEnv
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
from primaite.session.session import TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray.rllib.algorithms import ppo
@@ -17,64 +14,33 @@ from ray.rllib.algorithms import ppo
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
ray.init()
class RayPrimaiteGym(gymnasium.Env):
def __init__(self, env_config: Dict) -> None:
self.action_space = session.env.action_space
self.observation_space = session.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
obs, info = session.env.reset()
return obs, info
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
obs, reward, terminated, truncated, info = session.env.step(action)
return obs, reward, terminated, truncated, info
ray.shutdown()
ray.init()
config = {
"env": RayPrimaiteGym,
"env_config": {},
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
"num_rollout_workers": 0,
}
ray.shutdown()
ray.init()
self._algo = ppo.PPO(config=config)
# self._agent_config = (PPOConfig()
# .update_from_dict({
# "num_gpus":0,
# "num_workers":0,
# "batch_mode":"complete_episodes",
# "framework":"torch",
# })
# .environment(
# env="primaite",
# env_config={"session": session, "agents": session.rl_agents,},
# # disable_env_checking=True
# )
# # .rollouts(num_rollout_workers=0,
# # num_envs_per_worker=0)
# # .framework("tf2")
# .evaluation(evaluation_num_workers=0)
# )
# self._agent:Algorithm = self._agent_config.build(use_copy=False)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
for ep in range(n_episodes):
res = self._algo.train()
print(f"Episode {ep} complete, reward: {res['episode_reward_mean']}")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
raise NotImplementedError
for ep in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(self.session.game.options.max_episode_length):
action = self._algo.compute_single_action(observation=obs, explore=False)
obs, rew, term, trunc, info = self.session.env.step(action)
def save(self, save_path: Path) -> None:
"""Save the policy to a file."""
@@ -85,6 +51,6 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "RaySingleAgentPolicy":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -11,13 +11,13 @@ from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame, TrainingOptions
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteGame", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
@@ -75,6 +75,6 @@ class SB3Policy(PolicyABC, identifier="SB3"):
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteGame") -> "SB3Policy":
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -5,8 +5,8 @@ from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import load
from primaite.game.game import PrimaiteGame
from primaite.config.load import example_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
@@ -32,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteGame.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()
@@ -42,6 +42,6 @@ if __name__ == "__main__":
args = parser.parse_args()
if not args.config:
_LOGGER.error("Please provide a config file using the --config " "argument")
args.config = example_config_path()
run(session_path=args.config)
run(args.config)

View File

@@ -10,7 +10,7 @@
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.game.environment import PrimaiteRayEnv"
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
@@ -61,7 +61,7 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.environment import PrimaiteRayMARLEnv\n",
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"\n",
"\n",
"env_config = {\"game\":game}\n",

View File

@@ -10,7 +10,7 @@
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.game.environment import PrimaiteRayEnv"
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
@@ -102,7 +102,11 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"from primaite.config.load import example_config_path\n",
"from primaite.main import run\n",
"run(example_config_path())"
]
}
],
"metadata": {

View File

@@ -2,18 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"from primaite.game.environment import PrimaiteGymEnv\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -22,24 +22,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing DNSServer on node domain_controller\n",
"installing DatabaseClient on node web_server\n",
"installing WebServer on node web_server\n",
"installing DatabaseService on node database_server\n",
"service type not found DatabaseBackup\n",
"installing DataManipulationBot on node client_1\n",
"installing DNSClient on node client_1\n",
"installing DNSClient on node client_2\n"
]
}
],
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
@@ -49,16 +34,16 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteGymEnv(game=game, agents=game.rl_agents)"
"gym = PrimaiteGymEnv(game=game)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -67,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -76,27 +61,16 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<stable_baselines3.ppo.ppo.PPO at 0x7f75352526e0>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"model.learn(total_timesteps=1000)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [

View File

@@ -1,12 +1,14 @@
from enum import Enum
from typing import Dict, List, Literal, Optional
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.game.environment import PrimaiteGymEnv
from primaite.game.game import PrimaiteGame
# from primaite.game.game import PrimaiteGame
from primaite.game.policy.policy import PolicyABC
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
@@ -15,7 +17,7 @@ class TrainingOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent"]
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
@@ -38,7 +40,7 @@ class SessionMode(Enum):
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self):
def __init__(self, game: PrimaiteGame):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
@@ -46,8 +48,8 @@ class PrimaiteSession:
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
"""The environment that the RL algorithm can consume."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
@@ -55,6 +57,9 @@ class PrimaiteSession:
self.io_manager = SessionIO()
"""IO manager for the session."""
self.game: PrimaiteGame = game
"""Primaite Game object for managing main simulation loop and agents."""
def start_session(self) -> None:
"""Commence the training/eval session."""
self.mode = SessionMode.TRAIN
@@ -83,10 +88,26 @@ class PrimaiteSession:
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
sess = cls()
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"game": game})
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config={"game": game})
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game=game)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess