diff --git a/src/primaite/game/policy/rllib.py b/src/primaite/game/policy/rllib.py index f45b9fd6..fcebf40d 100644 --- a/src/primaite/game/policy/rllib.py +++ b/src/primaite/game/policy/rllib.py @@ -2,13 +2,15 @@ from pathlib import Path from typing import Literal, Optional, TYPE_CHECKING from primaite.game.policy.policy import PolicyABC -from primaite.session.environment import PrimaiteRayEnv +from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv if TYPE_CHECKING: from primaite.session.session import PrimaiteSession, TrainingOptions import ray +from ray import air, tune from ray.rllib.algorithms import ppo +from ray.rllib.algorithms.ppo import PPOConfig class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): @@ -54,3 +56,50 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"): def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy": """Create a policy from a config.""" return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) + + +class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"): + """Multi agent RL policy using Ray RLLib.""" + + def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None): + """Initialise multi agent policy wrapper.""" + super().__init__(session=session) + + self.config = ( + PPOConfig() + .environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game}) + .rollouts(num_rollout_workers=0) + .multi_agent( + policies={agent.agent_name for agent in session.game.rl_agents}, + policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id, + ) + .training(train_batch_size=128) + ) + + def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: + """Train the agent.""" + tune.Tuner( + "PPO", + run_config=air.RunConfig( + stop={"training_iteration": n_episodes * timesteps_per_episode}, + checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10), 
+ ), + param_space=self.config, + ).fit() + + def load(self, model_path: Path) -> None: + """Load policy parameters from a file.""" + return NotImplemented + + def eval(self, n_episodes: int, deterministic: bool) -> None: + """Evaluate trained policy.""" + return NotImplemented + + def save(self, save_path: Path) -> None: + """Save policy parameters to a file.""" + return NotImplemented + + @classmethod + def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy": + """Create policy from config.""" + return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed) diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml new file mode 100644 index 00000000..9d71e093 --- /dev/null +++ b/tests/assets/configs/multi_agent_session.yaml @@ -0,0 +1,1166 @@ +training_config: + rl_framework: RLLIB_multi_agent + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 2 + n_eval_episodes: 1 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: #not used :( + - defender1 + - defender2 + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + + +game: + max_episode_length: 128 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: 
RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + - operating_status + - health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + #