diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml new file mode 100644 index 00000000..9450c419 --- /dev/null +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -0,0 +1,1164 @@ +training_config: + rl_framework: RLLIB_single_agent + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 256 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_1_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + # + # - type: NODE_LOGON + # - type: NODE_LOGOFF + # - type: NODE_APPLICATION_EXECUTE + # options: + # execution_definition: + # target_address: arcd.com + + options: + nodes: + - node_ref: client_2 + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_nics_per_node: 2 + max_acl_rules: 10 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_step: 5 + frequency: 4 + variance: 3 + + - ref: client_1_data_manipulation_red_bot + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: + - node_ref: client_1 + observations: + - logon_status + - operating_status + services: + - service_ref: data_manipulation_bot + observations: + operating_status + health_status + folders: {} + + action_space: + action_list: + - type: DONOTHING + # Tuple[ObsType, SupportsFloat, bool, bool, Dict]: """Perform a step in the environment.""" return self.env.step(action) + + +class PrimaiteRayMARLEnv(MultiAgentEnv): + """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" + + def __init__(self, env_config: Optional[Dict] = None) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict[str, PrimaiteGame] + """ + self.game: PrimaiteGame = env_config["game"] + """Reference to the primaite game""" + self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} + """List of all possible agents in the environment. This list should not change!""" + self._agent_ids = list(self.agents.keys()) + + self.terminateds = set() + self.truncateds = set() + self.observation_space = gymnasium.spaces.Dict( + {name: agent.observation_manager.space for name, agent in self.agents.items()} + ) + self.action_space = gymnasium.spaces.Dict( + {name: agent.action_manager.space for name, agent in self.agents.items()} + ) + super().__init__() + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + self.game.reset() + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + def step( + self, actions: Dict[str, ActType] + ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: + """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. + + :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. + :type actions: Dict[str, ActType] + :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent + identifier. + :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] + """ + # 1. Perform actions + for agent_name, action in actions.items(): + self.agents[agent_name].store_action(action) + self.game.apply_agent_actions() + + # 2. Advance timestep + self.game.advance_timestep() + + # 3. Get next observations + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + + # 4. Get rewards + rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + terminateds = {name: False for name, _ in self.agents.items()} + truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} + infos = {} + terminateds["__all__"] = len(self.terminateds) == len(self.agents) + truncateds["__all__"] = self.game.calculate_truncated() + return next_obs, rewards, terminateds, truncateds, infos + + def _get_obs(self) -> Dict[str, ObsType]: + """Return the current observation.""" + return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()} diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb new file mode 100644 index 00000000..9f916af9 --- /dev/null +++ b/src/primaite/notebooks/training_example_ray_multi_agent.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.game import PrimaiteGame\n", + "import yaml\n", + "from primaite.config.load import example_config_path\n", + "\n", + "from primaite.game.environment import PrimaiteRayEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(example_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "game = PrimaiteGame.from_config(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gym = PrimaiteRayEnv({\"game\":game})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "from ray import air, tune\n", + "from ray.rllib.algorithms.ppo import PPOConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()\n", + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.game.environment import PrimaiteRayMARLEnv\n", + "\n", + "\n", + "env_config = {\"game\":game}\n", + "config = (\n", + " PPOConfig()\n", + " .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n", + " .rollouts(num_rollout_workers=0)\n", + " .multi_agent(\n", + " policies={agent.agent_name for agent in game.rl_agents},\n", + " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", + " )\n", + " .training(train_batch_size=128)\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tune.Tuner(\n", + " \"PPO\",\n", + " run_config=air.RunConfig(\n", + " stop={\"training_iteration\": 128},\n", + " checkpoint_config=air.CheckpointConfig(\n", + " checkpoint_frequency=10,\n", + " ),\n", + " ),\n", + " param_space=config\n", + ").fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}