diff --git a/.azure/azure-build-deploy-docs-pipeline.yml b/.azure/azure-build-deploy-docs-pipeline.yml index 01adce6d..8ebfe4d6 100644 --- a/.azure/azure-build-deploy-docs-pipeline.yml +++ b/.azure/azure-build-deploy-docs-pipeline.yml @@ -26,7 +26,7 @@ jobs: displayName: 'Install build dependencies' - script: | - pip install -e .[dev] + pip install -e .[dev,rl] displayName: 'Install PrimAITE for docs autosummary' - script: | diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index f0a1793e..aea94807 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -14,31 +14,31 @@ parameters: - name: matrix type: object default: - - job_name: 'UbuntuPython38' - py: '3.8' - img: 'ubuntu-latest' - every_time: false - publish_coverage: false + # - job_name: 'UbuntuPython38' + # py: '3.8' + # img: 'ubuntu-latest' + # every_time: false + # publish_coverage: false - job_name: 'UbuntuPython311' py: '3.11' img: 'ubuntu-latest' every_time: true publish_coverage: true - - job_name: 'WindowsPython38' - py: '3.8' - img: 'windows-latest' - every_time: false - publish_coverage: false + # - job_name: 'WindowsPython38' + # py: '3.8' + # img: 'windows-latest' + # every_time: false + # publish_coverage: false - job_name: 'WindowsPython311' py: '3.11' img: 'windows-latest' every_time: false publish_coverage: false - - job_name: 'MacOSPython38' - py: '3.8' - img: 'macOS-latest' - every_time: false - publish_coverage: false + # - job_name: 'MacOSPython38' + # py: '3.8' + # img: 'macOS-latest' + # every_time: false + # publish_coverage: false - job_name: 'MacOSPython311' py: '3.11' img: 'macOS-latest' @@ -82,12 +82,12 @@ stages: - script: | PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) - python -m pip install $PRIMAITE_WHEEL[dev] + python -m pip install $PRIMAITE_WHEEL[dev,rl] displayName: 'Install PrimAITE' condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - script: | - forfiles /p 
dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]" + forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]" displayName: 'Install PrimAITE' condition: eq( variables['Agent.OS'], 'Windows_NT' ) diff --git a/.github/workflows/build-sphinx.yml b/.github/workflows/build-sphinx.yml index 82da1c6b..da20fbd3 100644 --- a/.github/workflows/build-sphinx.yml +++ b/.github/workflows/build-sphinx.yml @@ -49,7 +49,7 @@ jobs: - name: Install PrimAITE for docs autosummary run: | set -x - python -m pip install -e .[dev] + python -m pip install -e .[dev,rl] - name: Run build script for Sphinx pages env: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ed94ad97..1b85f4be 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -48,7 +48,7 @@ jobs: - name: Install PrimAITE run: | PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) - python -m pip install $PRIMAITE_WHEEL[dev] + python -m pip install $PRIMAITE_WHEEL[dev,rl] - name: Perform PrimAITE Setup run: | diff --git a/README.md b/README.md index 3fd73b53..68a8488b 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ cd ~\primaite python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory .\.venv\Scripts\activate -pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl +pip install primaite-3.0.0-py3-none-any.whl[rl] primaite setup ``` @@ -66,7 +66,7 @@ mkdir ~/primaite cd ~/primaite python3 -m venv .venv source .venv/bin/activate -pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl +pip install primaite-3.0.0-py3-none-any.whl[rl] primaite setup ``` @@ -105,7 +105,7 @@ source venv/bin/activate #### 5. 
Install `primaite` with the dev extra into the venv along with all of it's dependencies ```bash -python3 -m pip install -e .[dev] +python3 -m pip install -e .[dev,rl] ``` #### 6. Perform the PrimAITE setup: @@ -114,6 +114,9 @@ python3 -m pip install -e .[dev] primaite setup ``` +#### Note +*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` flag in the pip install command.* + ### Running PrimAITE Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`. diff --git a/docs/_static/firewall_acl.png b/docs/_static/firewall_acl.png index 1cdd2526..1e596575 100644 Binary files a/docs/_static/firewall_acl.png and b/docs/_static/firewall_acl.png differ diff --git a/docs/api.rst b/docs/api.rst index 13f3a1ec..e74be627 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,4 +19,3 @@ :recursive: primaite - tests diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 6b7d7542..9283f3f4 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -82,7 +82,7 @@ Install PrimAITE .. code-block:: bash :caption: Unix - pip install path/to/your/primaite.whl + pip install path/to/your/primaite.whl[rl] .. code-block:: powershell :caption: Windows (Powershell) @@ -135,12 +135,12 @@ For example: .. code-block:: bash :caption: Unix - pip install -e .[dev] + pip install -e .[dev,rl] .. code-block:: powershell :caption: Windows (Powershell) - pip install -e .[dev] + pip install -e .[dev,rl] To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`). 
diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 00b2dc79..f253d10e 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -66,4 +66,7 @@ Glossary The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours. Gymnasium - PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards. \ No newline at end of file + PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards. + + User app home + PrimAITE supports upgrading the software version while retaining user data. The user data directory is where configs, notebooks, and results are stored; this location is `~/primaite/` on Linux/Darwin and `C:\\Users\\<username>\\primaite` on Windows. 
diff --git a/pyproject.toml b/pyproject.toml index 5d913e1a..d01299be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,12 +36,10 @@ dependencies = [ "polars==0.18.4", "prettytable==3.8.0", "PyYAML==6.0", - "stable-baselines3[extra]==2.1.0", - "tensorflow==2.12.0", "typer[all]==0.9.0", "pydantic==2.7.0", - "ray[rllib] >= 2.9, < 3", - "ipywidgets" + "ipywidgets", + "deepdiff" ] [tool.setuptools.dynamic] @@ -55,6 +53,11 @@ license-files = ["LICENSE"] [project.optional-dependencies] +rl = [ + "ray[rllib] >= 2.20.0, < 3", + "tensorflow==2.12.0", + "stable-baselines3[extra]==2.1.0", +] dev = [ "build==0.10.0", "flake8==6.0.0", diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index cd4a1c29..444aa4f7 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: pass -class AgentActionHistoryItem(BaseModel): +class AgentHistoryItem(BaseModel): """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" timestep: int @@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel): response: RequestResponse """The response sent back by the simulator for this action.""" + reward: Optional[float] = None + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -110,7 +112,7 @@ class AbstractAgent(ABC): self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() - self.action_history: List[AgentActionHistoryItem] = [] + self.history: List[AgentHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -130,7 +132,7 @@ class AbstractAgent(ABC): :return: Reward from the state. 
:rtype: float """ - return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) + return self.reward_function.update(state=state, last_action_response=self.history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -161,12 +163,16 @@ class AbstractAgent(ABC): self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse ) -> None: """Process the response from the most recent action.""" - self.action_history.append( - AgentActionHistoryItem( + self.history.append( + AgentHistoryItem( timestep=timestep, action=action, parameters=parameters, request=request, response=response ) ) + def save_reward_to_history(self) -> None: + """Update the most recent history item with the reward value.""" + self.history[-1].reward = self.reward_function.current_reward + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 0222bfcc..d77640d1 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE if TYPE_CHECKING: - from primaite.game.agent.interface import AgentActionHistoryItem + from primaite.game.agent.interface import AgentHistoryItem _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] @@ -44,7 +44,7 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -64,7 +64,7 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward 
function component which always returns 0.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. 
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -343,7 +343,7 @@ class SharedReward(AbstractReward): self.callback: Callable[[str], float] = default_callback """Method that retrieves an agent's current reward given the agent's name.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it.""" return self.callback(self.agent_name) @@ -389,7 +389,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. 
diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ea5b3831..772ab5aa 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,6 +160,7 @@ class PrimaiteGame: agent = self.agents[agent_name] if self.step_counter > 0: # can't get reward before first action agent.update_reward(state=state) + agent.save_reward_to_history() agent.update_observation(state=state) # order of this doesn't matter so just use reward order agent.reward_function.total_reward += agent.reward_function.current_reward diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 1b016bb8..21d67bab 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -22,7 +22,7 @@ "# Imports\n", "\n", "from primaite.config.load import data_manipulation_config_path\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "from primaite.session.environment import PrimaiteGymEnv\n", "import yaml\n", "from pprint import pprint" @@ -63,7 +63,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index c6939afd..a3dab962 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -392,7 
+392,7 @@ "# Imports\n", "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -444,7 +444,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index df688146..5ffb19ad 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -25,13 +25,13 @@ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from primaite import PRIMAITE_PATHS\n", "\n", "import ray\n", "from ray import air, tune\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", - "from primaite.session.environment import PrimaiteRayMARLEnv\n", + "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", "\n", "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", "# to copy the files to your user data path.\n", @@ -60,8 +60,8 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " 
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", " )\n" ] diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 1fb66405..fbc5f4c6 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -18,8 +18,7 @@ "import yaml\n", "from primaite.config.load import data_manipulation_config_path\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv\n", - "from ray.rllib.algorithms import ppo\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray import air, tune\n", "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", @@ -52,8 +51,8 @@ "\n", "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayEnv, env_config=env_config)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", ")\n" ] @@ -74,7 +73,7 @@ "tune.Tuner(\n", " \"PPO\",\n", " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128}\n", + " stop={\"timesteps_total\": 512}\n", " ),\n", " param_space=config\n", ").fit()\n" @@ -97,7 +96,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 9faf5820..1e247e81 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -43,7 +43,10 @@ "outputs": [], "source": [ "with open(data_manipulation_config_path(), 'r') as 
f:\n", - " cfg = yaml.safe_load(f)" + " cfg = yaml.safe_load(f)\n", + "for agent in cfg['agents']:\n", + " if agent['ref'] == 'defender':\n", + " agent['agent_settings']['flatten_obs']=True" ] }, { @@ -177,7 +180,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Using-Episode-Schedules.ipynb b/src/primaite/notebooks/Using-Episode-Schedules.ipynb index c44339ae..fc9e04f7 100644 --- a/src/primaite/notebooks/Using-Episode-Schedules.ipynb +++ b/src/primaite/notebooks/Using-Episode-Schedules.ipynb @@ -254,8 +254,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_A'].action_history[i].action\n", - " red_action = env.game.agents['red_A'].action_history[i].action\n", + " green_action = env.game.agents['green_A'].history[i].action\n", + " red_action = env.game.agents['red_A'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] @@ -285,8 +285,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_B'].action_history[i].action\n", - " red_action = env.game.agents['red_B'].action_history[i].action\n", + " green_action = env.game.agents['green_B'].history[i].action\n", + " red_action = env.game.agents['red_B'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 4d0544e9..477efa9b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union import gymnasium from gymnasium.core import ActType, ObsType -from ray.rllib.env.multi_agent_env 
import MultiAgentEnv from primaite import getLogger from primaite.game.agent.interface import ProxyAgent @@ -12,6 +11,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT +from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) @@ -61,7 +61,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = { - "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + "agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()} } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(step, action, state, reward) @@ -90,9 +90,10 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) self.episode_counter += 1 + PacketCapture.clear() self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) state = self.game.get_sim_state() @@ -126,166 +127,5 @@ class PrimaiteGymEnv(gymnasium.Env): def close(self): """Close the simulation.""" if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - - -class 
PrimaiteRayEnv(gymnasium.Env): - """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" - - def __init__(self, env_config: Dict) -> None: - """Initialise the environment. - - :param env_config: A dictionary containing the environment configuration. - :type env_config: Dict - """ - self.env = PrimaiteGymEnv(env_config=env_config) - # self.env.episode_counter -= 1 - self.action_space = self.env.action_space - self.observation_space = self.env.observation_space - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - """Reset the environment.""" - return self.env.reset(seed=seed) - - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: - """Perform a step in the environment.""" - return self.env.step(action) - - def close(self): - """Close the simulation.""" - self.env.close() - - @property - def game(self) -> PrimaiteGame: - """Pass through game from env.""" - return self.env.game - - -class PrimaiteRayMARLEnv(MultiAgentEnv): - """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" - - def __init__(self, env_config: Dict) -> None: - """Initialise the environment. - - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict - """ - self.episode_counter: int = 0 - """Current episode number.""" - self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) - """Object that returns a config corresponding to the current episode.""" - self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) - """Handles IO for the environment. This produces sys logs, agent logs, etc.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) - """Reference to the primaite game""" - self._agent_ids = list(self.game.rl_agents.keys()) - """Agent ids. 
This is a list of strings of agent names.""" - - self.terminateds = set() - self.truncateds = set() - self.observation_space = gymnasium.spaces.Dict( - { - name: gymnasium.spaces.flatten_space(agent.observation_manager.space) - for name, agent in self.agents.items() - } - ) - self.action_space = gymnasium.spaces.Dict( - {name: agent.action_manager.space for name, agent in self.agents.items()} - ) - - super().__init__() - - @property - def agents(self) -> Dict[str, ProxyAgent]: - """Grab a fresh reference to the agents from this episode's game object.""" - return {name: self.game.rl_agents[name] for name in self._agent_ids} - - def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: - """Reset the environment.""" - rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} - _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") - - if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - - self.episode_counter += 1 - self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) - self.game.setup_for_episode(episode=self.episode_counter) - state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() - info = {} - return next_obs, info - - def step( - self, actions: Dict[str, ActType] - ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: - """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. - - :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. - :type actions: Dict[str, ActType] - :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent - identifier. 
- :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] - """ - step = self.game.step_counter - # 1. Perform actions - for agent_name, action in actions.items(): - self.agents[agent_name].store_action(action) - self.game.pre_timestep() - self.game.apply_agent_actions() - - # 2. Advance timestep - self.game.advance_timestep() - - # 3. Get next observations - state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() - - # 4. Get rewards - rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} - _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") - terminateds = {name: False for name, _ in self.agents.items()} - truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {name: {} for name, _ in self.agents.items()} - terminateds["__all__"] = len(self.terminateds) == len(self.agents) - truncateds["__all__"] = self.game.calculate_truncated() - if self.game.save_step_metadata: - self._write_step_metadata_json(step, actions, state, rewards) - return next_obs, rewards, terminateds, truncateds, infos - - def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" - - output_dir.mkdir(parents=True, exist_ok=True) - path = output_dir / f"step_{step}.json" - - data = { - "episode": self.episode_counter, - "step": step, - "actions": {agent_name: int(action) for agent_name, action in actions.items()}, - "reward": rewards, - "state": state, - } - with open(path, "w") as file: - json.dump(data, file) - - def _get_obs(self) -> Dict[str, ObsType]: - """Return the current observation.""" - obs = {} - for agent_name in self._agent_ids: - agent = self.game.rl_agents[agent_name] - unflat_space = agent.observation_manager.space - unflat_obs = agent.observation_manager.current_observation - 
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) - return obs - - def close(self): - """Close the simulation.""" - if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 8bbc1b07..2901457f 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -87,7 +87,7 @@ class PrimaiteIO: """Return the path where agent actions will be saved.""" return self.session_path / "agent_actions" / f"episode_{episode}.json" - def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: + def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None: """Take the contents of the agent action log and write it to a file. 
:param episode: Episode number diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py new file mode 100644 index 00000000..f9ab3405 --- /dev/null +++ b/src/primaite/session/ray_envs.py @@ -0,0 +1,177 @@ +import json +from typing import Dict, SupportsFloat, Tuple + +import gymnasium +from gymnasium.core import ActType, ObsType +from ray.rllib.env.multi_agent_env import MultiAgentEnv + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.session.environment import _LOGGER, PrimaiteGymEnv +from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler +from primaite.session.io import PrimaiteIO +from primaite.simulator import SIM_OUTPUT +from primaite.simulator.system.core.packet_capture import PacketCapture + + +class PrimaiteRayMARLEnv(MultiAgentEnv): + """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` + which is the PrimaiteGame instance. + :type env_config: Dict + """ + self.episode_counter: int = 0 + """Current episode number.""" + self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) + """Object that returns a config corresponding to the current episode.""" + self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) + """Reference to the primaite game""" + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. 
This is a list of strings of agent names.""" + + self.terminateds = set() + self.truncateds = set() + self.observation_space = gymnasium.spaces.Dict( + { + name: gymnasium.spaces.flatten_space(agent.observation_manager.space) + for name, agent in self.agents.items() + } + ) + self.action_space = gymnasium.spaces.Dict( + {name: agent.action_manager.space for name, agent in self.agents.items()} + ) + self._obs_space_in_preferred_format = True + self._action_space_in_preferred_format = True + super().__init__() + + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} + _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") + + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) + + self.episode_counter += 1 + PacketCapture.clear() + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) + self.game.setup_for_episode(episode=self.episode_counter) + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + info = {} + return next_obs, info + + def step( + self, actions: Dict[str, ActType] + ) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]: + """Perform a step in the environment. Adherent to Ray MultiAgentEnv step API. + + :param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance. 
+ :type actions: Dict[str, ActType] + :return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent + identifier. + :rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict] + """ + step = self.game.step_counter + # 1. Perform actions + for agent_name, action in actions.items(): + self.agents[agent_name].store_action(action) + self.game.pre_timestep() + self.game.apply_agent_actions() + + # 2. Advance timestep + self.game.advance_timestep() + + # 3. Get next observations + state = self.game.get_sim_state() + self.game.update_agents(state) + next_obs = self._get_obs() + + # 4. Get rewards + rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} + _LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}") + terminateds = {name: False for name, _ in self.agents.items()} + truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} + infos = {name: {} for name, _ in self.agents.items()} + terminateds["__all__"] = len(self.terminateds) == len(self.agents) + truncateds["__all__"] = self.game.calculate_truncated() + if self.game.save_step_metadata: + self._write_step_metadata_json(step, actions, state, rewards) + return next_obs, rewards, terminateds, truncateds, infos + + def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict): + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{step}.json" + + data = { + "episode": self.episode_counter, + "step": step, + "actions": {agent_name: int(action) for agent_name, action in actions.items()}, + "reward": rewards, + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + + def _get_obs(self) -> Dict[str, ObsType]: + """Return the current observation.""" + obs = {} + for agent_name in self._agent_ids: + agent 
= self.game.rl_agents[agent_name] + unflat_space = agent.observation_manager.space + unflat_obs = agent.observation_manager.current_observation + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + return obs + + def close(self): + """Close the simulation.""" + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) + + +class PrimaiteRayEnv(gymnasium.Env): + """Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray.""" + + def __init__(self, env_config: Dict) -> None: + """Initialise the environment. + + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict + """ + self.env = PrimaiteGymEnv(env_config=env_config) + # self.env.episode_counter -= 1 + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: + """Reset the environment.""" + return self.env.reset(seed=seed) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]: + """Perform a step in the environment.""" + return self.env.step(action) + + def close(self): + """Close the simulation.""" + self.env.close() + + @property + def game(self) -> PrimaiteGame: + """Pass through game from env.""" + return self.env.game diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index cf38e94b..bc8a0584 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -21,6 +21,8 @@ class PacketCapture: The PCAPs are logged to: //__pcap.log """ + _logger_instances: List[logging.Logger] = [] + def __init__( self, hostname: str, @@ -65,10 +67,12 @@ class PacketCapture: if 
outbound: self.outbound_logger = logging.getLogger(self._get_logger_name(outbound)) + PacketCapture._logger_instances.append(self.outbound_logger) logger = self.outbound_logger else: self.inbound_logger = logging.getLogger(self._get_logger_name(outbound)) logger = self.inbound_logger + PacketCapture._logger_instances.append(self.inbound_logger) logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs logger.addHandler(file_handler) @@ -122,3 +126,13 @@ class PacketCapture: if SIM_OUTPUT.save_pcap_logs: msg = frame.model_dump_json() self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL + + @staticmethod + def clear(): + """Close all open PCAP file handlers.""" + for logger in PacketCapture._logger_instances: + handlers = logger.handlers[:] + for handler in handlers: + logger.removeHandler(handler) + handler.close() + PacketCapture._logger_instances = [] diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 712a16c4..9b550dd2 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -3,7 +3,7 @@ import yaml from ray import air, tune from ray.rllib.algorithms.ppo import PPOConfig -from primaite.session.environment import PrimaiteRayMARLEnv +from primaite.session.ray_envs import PrimaiteRayMARLEnv from tests import TEST_ASSETS_ROOT MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index d9057fef..f56f0f85 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ 
b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame -from primaite.session.environment import PrimaiteRayEnv +from primaite.session.ray_envs import PrimaiteRayEnv @pytest.mark.skip(reason="Slow, reenable later") diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index accfad50..0a2c6add 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -4,7 +4,8 @@ import yaml from gymnasium.core import ObsType from numpy import ndarray -from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayMARLEnv from primaite.simulator.network.hardware.nodes.host.server import Printer from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py index 6b40fb1a..c6fd1a2f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -1,7 +1,8 @@ import pytest import yaml -from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv from tests.conftest import TEST_ASSETS_ROOT folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders" diff --git a/tests/integration_tests/game_layer/test_rewards.py 
b/tests/integration_tests/game_layer/test_rewards.py index 7c38057e..dff536de 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,6 +1,6 @@ import yaml -from primaite.game.agent.interface import AgentActionHistoryItem +from primaite.game.agent.interface import AgentHistoryItem from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv @@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), ) @@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), )