diff --git a/.azure/azure-build-deploy-docs-pipeline.yml b/.azure/azure-build-deploy-docs-pipeline.yml index 01adce6d..8ebfe4d6 100644 --- a/.azure/azure-build-deploy-docs-pipeline.yml +++ b/.azure/azure-build-deploy-docs-pipeline.yml @@ -26,7 +26,7 @@ jobs: displayName: 'Install build dependencies' - script: | - pip install -e .[dev] + pip install -e .[dev,rl] displayName: 'Install PrimAITE for docs autosummary' - script: | diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index f0a1793e..a32ae20f 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -82,12 +82,12 @@ stages: - script: | PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) - python -m pip install $PRIMAITE_WHEEL[dev] + python -m pip install $PRIMAITE_WHEEL[dev,rl] displayName: 'Install PrimAITE' condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - script: | - forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]" + forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]" displayName: 'Install PrimAITE' condition: eq( variables['Agent.OS'], 'Windows_NT' ) diff --git a/.github/workflows/build-sphinx.yml b/.github/workflows/build-sphinx.yml index 82da1c6b..da20fbd3 100644 --- a/.github/workflows/build-sphinx.yml +++ b/.github/workflows/build-sphinx.yml @@ -49,7 +49,7 @@ jobs: - name: Install PrimAITE for docs autosummary run: | set -x - python -m pip install -e .[dev] + python -m pip install -e .[dev,rl] - name: Run build script for Sphinx pages env: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ed94ad97..1b85f4be 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -48,7 +48,7 @@ jobs: - name: Install PrimAITE run: | PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) - python -m pip install $PRIMAITE_WHEEL[dev] + python -m pip install $PRIMAITE_WHEEL[dev,rl] - name: Perform 
PrimAITE Setup run: | diff --git a/README.md b/README.md index 3fd73b53..68a8488b 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ cd ~\primaite python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory .\.venv\Scripts\activate -pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl +pip install primaite-3.0.0-py3-none-any.whl[rl] primaite setup ``` @@ -66,7 +66,7 @@ mkdir ~/primaite cd ~/primaite python3 -m venv .venv source .venv/bin/activate -pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl +pip install primaite-3.0.0-py3-none-any.whl[rl] primaite setup ``` @@ -105,7 +105,7 @@ source venv/bin/activate #### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies ```bash -python3 -m pip install -e .[dev] +python3 -m pip install -e .[dev,rl] ``` #### 6. Perform the PrimAITE setup: @@ -114,6 +114,9 @@ python3 -m pip install -e .[dev] primaite setup ``` +#### Note +*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` extra in the pip install command.* + ### Running PrimAITE Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`. diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 7c91498c..6e6fc3e4 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -82,7 +82,7 @@ Install PrimAITE .. code-block:: bash :caption: Unix - pip install path/to/your/primaite.whl + pip install path/to/your/primaite.whl[rl] .. code-block:: powershell :caption: Windows (Powershell) @@ -133,12 +133,12 @@ of your choice: .. code-block:: bash :caption: Unix - pip install -e .[dev] + pip install -e .[dev,rl] ..
code-block:: powershell :caption: Windows (Powershell) - pip install -e .[dev] + pip install -e .[dev,rl] To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`). diff --git a/pyproject.toml b/pyproject.toml index 5d913e1a..008f7c9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,11 +36,8 @@ dependencies = [ "polars==0.18.4", "prettytable==3.8.0", "PyYAML==6.0", - "stable-baselines3[extra]==2.1.0", - "tensorflow==2.12.0", "typer[all]==0.9.0", "pydantic==2.7.0", - "ray[rllib] >= 2.9, < 3", "ipywidgets" ] @@ -55,6 +52,11 @@ license-files = ["LICENSE"] [project.optional-dependencies] +rl = [ + "ray[rllib] >= 2.9, < 3", + "tensorflow==2.12.0", + "stable-baselines3[extra]==2.1.0", +] dev = [ "build==0.10.0", "flake8==6.0.0", diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 1fb66405..9d458426 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -18,7 +18,7 @@ "import yaml\n", "from primaite.config.load import data_manipulation_config_path\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms import ppo\n", "from ray import air, tune\n", "import ray\n", @@ -97,7 +97,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 6c42c701..1e9faded 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union import gymnasium from gymnasium.core import ActType, ObsType -from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite import 
getLogger from primaite.game.agent.interface import ProxyAgent @@ -128,164 +127,3 @@ class PrimaiteGymEnv(gymnasium.Env): if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - - -class PrimaiteRayEnv(gymnasium.Env): - """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" - - def __init__(self, env_config: Dict) -> None: - """Initialise the environment. - - :param env_config: A dictionary containing the environment configuration. - :type env_config: Dict - """ - self.env = PrimaiteGymEnv(env_config=env_config) - # self.env.episode_counter -= 1 - self.action_space = self.env.action_space - self.observation_space = self.env.observation_space - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - """Reset the environment.""" - return self.env.reset(seed=seed) - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: - """Perform a step in the environment.""" - return self.env.step(action) - - def close(self): - """Close the simulation.""" - self.env.close() - - @property - def game(self) -> PrimaiteGame: - """Pass through game from env.""" - return self.env.game - - -class PrimaiteRayMARLEnv(MultiAgentEnv): - """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" - - def __init__(self, env_config: Dict) -> None: - """Initialise the environment. - - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. 
- :type env_config: Dict - """ - self.episode_counter: int = 0 - """Current episode number.""" - self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) - """Object that returns a config corresponding to the current episode.""" - self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) - """Handles IO for the environment. This produces sys logs, agent logs, etc.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) - """Reference to the primaite game""" - self._agent_ids = list(self.game.rl_agents.keys()) - """Agent ids. This is a list of strings of agent names.""" - - self.terminateds = set() - self.truncateds = set() - self.observation_space = gymnasium.spaces.Dict( - { - name: gymnasium.spaces.flatten_space(agent.observation_manager.space) - for name, agent in self.agents.items() - } - ) - self.action_space = gymnasium.spaces.Dict( - {name: agent.action_manager.space for name, agent in self.agents.items()} - ) - - super().__init__() - - @property - def agents(self) -> Dict[str, ProxyAgent]: - """Grab a fresh reference to the agents from this episode's game object.""" - return {name: self.game.rl_agents[name] for name in self._agent_ids} - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - """Reset the environment.""" - rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} - _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. 
reward: {rewards}") - - if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - - self.episode_counter += 1 - self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) - self.game.setup_for_episode(episode=self.episode_counter) - state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() - info = {} - return next_obs, info - - def step( - self, actions: Dict[str, ActType] - ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: - """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. - - :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. - :type actions: Dict[str, ActType] - :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent - identifier. - :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] - """ - step = self.game.step_counter - # 1. Perform actions - for agent_name, action in actions.items(): - self.agents[agent_name].store_action(action) - self.game.pre_timestep() - self.game.apply_agent_actions() - - # 2. Advance timestep - self.game.advance_timestep() - - # 3. Get next observations - state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() - - # 4. 
Get rewards - rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} - _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") - terminateds = {name: False for name, _ in self.agents.items()} - truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {name: {} for name, _ in self.agents.items()} - terminateds["__all__"] = len(self.terminateds) == len(self.agents) - truncateds["__all__"] = self.game.calculate_truncated() - if self.game.save_step_metadata: - self._write_step_metadata_json(step, actions, state, rewards) - return next_obs, rewards, terminateds, truncateds, infos - - def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" - - output_dir.mkdir(parents=True, exist_ok=True) - path = output_dir / f"step_{step}.json" - - data = { - "episode": self.episode_counter, - "step": step, - "actions": {agent_name: int(action) for agent_name, action in actions.items()}, - "reward": rewards, - "state": state, - } - with open(path, "w") as file: - json.dump(data, file) - - def _get_obs(self) -> Dict[str, ObsType]: - """Return the current observation.""" - obs = {} - for agent_name in self._agent_ids: - agent = self.game.rl_agents[agent_name] - unflat_space = agent.observation_manager.space - unflat_obs = agent.observation_manager.current_observation - obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) - return obs - - def close(self): - """Close the simulation.""" - if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py new file mode 100644 index 00000000..5149a225 --- /dev/null +++ 
b/src/primaite/session/ray_envs.py @@ -0,0 +1,174 @@ +import json +from typing import Dict, SupportsFloat, Tuple + +import gymnasium +from gymnasium.core import ActType, ObsType +from ray.rllib.env.multi_agent_env import MultiAgentEnv + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.session.environment import _LOGGER, PrimaiteGymEnv +from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler +from primaite.session.io import PrimaiteIO +from primaite.simulator import SIM_OUTPUT + + +class PrimaiteRayMARLEnv(MultiAgentEnv): + """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict + """ + self.episode_counter: int = 0 + """Current episode number.""" + self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) + """Object that returns a config corresponding to the current episode.""" + self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) + """Reference to the primaite game""" + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. 
This is a list of strings of agent names.""" + + self.terminateds = set() + self.truncateds = set() + self.observation_space = gymnasium.spaces.Dict( + { + name: gymnasium.spaces.flatten_space(agent.observation_manager.space) + for name, agent in self.agents.items() + } + ) + self.action_space = gymnasium.spaces.Dict( + {name: agent.action_manager.space for name, agent in self.agents.items()} + ) + + super().__init__() + + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} + _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") + + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + + self.episode_counter += 1 + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) + self.game.setup_for_episode(episode=self.episode_counter) + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + def step( + self, actions: Dict[str, ActType] + ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: + """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. + + :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. + :type actions: Dict[str, ActType] + :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent + identifier. 
+ :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] + """ + step = self.game.step_counter + # 1. Perform actions + for agent_name, action in actions.items(): + self.agents[agent_name].store_action(action) + self.game.pre_timestep() + self.game.apply_agent_actions() + + # 2. Advance timestep + self.game.advance_timestep() + + # 3. Get next observations + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + + # 4. Get rewards + rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") + terminateds = {name: False for name, _ in self.agents.items()} + truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} + infos = {name: {} for name, _ in self.agents.items()} + terminateds["__all__"] = len(self.terminateds) == len(self.agents) + truncateds["__all__"] = self.game.calculate_truncated() + if self.game.save_step_metadata: + self._write_step_metadata_json(step, actions, state, rewards) + return next_obs, rewards, terminateds, truncateds, infos + + def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict): + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{step}.json" + + data = { + "episode": self.episode_counter, + "step": step, + "actions": {agent_name: int(action) for agent_name, action in actions.items()}, + "reward": rewards, + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + + def _get_obs(self) -> Dict[str, ObsType]: + """Return the current observation.""" + obs = {} + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] + unflat_space = agent.observation_manager.space + unflat_obs = agent.observation_manager.current_observation + 
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + return obs + + def close(self): + """Close the simulation.""" + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + + +class PrimaiteRayEnv(gymnasium.Env): + """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict + """ + self.env = PrimaiteGymEnv(env_config=env_config) + # self.env.episode_counter -= 1 + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + return self.env.reset(seed=seed) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + """Perform a step in the environment.""" + return self.env.step(action) + + def close(self): + """Close the simulation.""" + self.env.close() + + @property + def game(self) -> PrimaiteGame: + """Pass through game from env.""" + return self.env.game diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 712a16c4..9b550dd2 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -3,7 +3,7 @@ import yaml from ray import air, tune from ray.rllib.algorithms.ppo import PPOConfig -from primaite.session.environment import PrimaiteRayMARLEnv +from primaite.session.ray_envs import PrimaiteRayMARLEnv from tests import 
TEST_ASSETS_ROOT MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index d9057fef..f56f0f85 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame -from primaite.session.environment import PrimaiteRayEnv +from primaite.session.ray_envs import PrimaiteRayEnv @pytest.mark.skip(reason="Slow, reenable later") diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index accfad50..0a2c6add 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -4,7 +4,8 @@ import yaml from gymnasium.core import ObsType from numpy import ndarray -from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayMARLEnv from primaite.simulator.network.hardware.nodes.host.server import Printer from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py index 6b40fb1a..c6fd1a2f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -1,7 +1,8 @@ import pytest import yaml -from primaite.session.environment import PrimaiteGymEnv, 
PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv from tests.conftest import TEST_ASSETS_ROOT folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"