diff --git a/docs/_static/firewall_acl.png b/docs/_static/firewall_acl.png index 1cdd2526..1e596575 100644 Binary files a/docs/_static/firewall_acl.png and b/docs/_static/firewall_acl.png differ diff --git a/docs/api.rst b/docs/api.rst index 13f3a1ec..e74be627 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,4 +19,3 @@ :recursive: primaite - tests diff --git a/docs/index.rst b/docs/index.rst index a0f302e9..8e7defb1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -125,7 +125,6 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/state_system source/request_system PrimAITE API - PrimAITE Tests .. toctree:: diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 67fd7aaa..c322caac 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -78,4 +78,4 @@ Glossary PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards. User app home - PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite` on linux/darwin and `C:\\Users\\\\primaite\\` on Windows. + PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite/` on linux/darwin and `C:\\Users\\\\primaite` on Windows. diff --git a/pyproject.toml b/pyproject.toml index 9c94a388..d01299be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.9, < 3", + "ray[rllib] >= 2.20.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", ] diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index cd4a1c29..444aa4f7 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: pass -class AgentActionHistoryItem(BaseModel): +class AgentHistoryItem(BaseModel): """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" timestep: int @@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel): response: RequestResponse """The response sent back by the simulator for this action.""" + reward: Optional[float] = None + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -110,7 +112,7 @@ class AbstractAgent(ABC): self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() - self.action_history: List[AgentActionHistoryItem] = [] + self.history: List[AgentHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -130,7 +132,7 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) + return self.reward_function.update(state=state, last_action_response=self.history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -161,12 +163,16 @@ class AbstractAgent(ABC): self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse ) -> None: """Process the response from the most recent action.""" - self.action_history.append( - AgentActionHistoryItem( + self.history.append( + AgentHistoryItem( timestep=timestep, action=action, parameters=parameters, request=request, response=response ) ) + def save_reward_to_history(self) -> None: + """Update the most recent history item with the reward value.""" + self.history[-1].reward = self.reward_function.current_reward + class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 0222bfcc..d77640d1 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -34,7 +34,7 @@ from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE if TYPE_CHECKING: - from primaite.game.agent.interface import AgentActionHistoryItem + from primaite.game.agent.interface import AgentHistoryItem _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] @@ -44,7 +44,7 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -64,7 +64,7 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] self._last_request_failed: bool = False - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ Calculate the reward based on current simulation state, and the recent agent action. @@ -343,7 +343,7 @@ class SharedReward(AbstractReward): self.callback: Callable[[str], float] = default_callback """Method that retrieves an agent's current reward given the agent's name.""" - def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it.""" return self.callback(self.agent_name) @@ -389,7 +389,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ea5b3831..772ab5aa 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,6 +160,7 @@ class PrimaiteGame: agent = self.agents[agent_name] if self.step_counter > 0: # can't get reward before first action agent.update_reward(state=state) + agent.save_reward_to_history() agent.update_observation(state=state) # order of this doesn't matter so just use reward order agent.reward_function.total_reward += agent.reward_function.current_reward diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 1b016bb8..21d67bab 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -22,7 +22,7 @@ "# Imports\n", "\n", "from primaite.config.load import data_manipulation_config_path\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "from primaite.session.environment import PrimaiteGymEnv\n", "import yaml\n", "from pprint import pprint" @@ -63,7 +63,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 8104149e..376b7f28 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -392,7 +392,7 @@ "# Imports\n", "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.game.agent.interface import AgentHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -444,7 +444,7 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", @@ -705,7 +705,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 65b1595f..5ffb19ad 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -25,7 +25,7 @@ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", "\n", - "from primaite.session.environment import PrimaiteRayEnv\n", + "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from primaite import PRIMAITE_PATHS\n", "\n", "import ray\n", @@ -60,8 +60,8 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", " )\n" ] diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 9d458426..fbc5f4c6 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -19,7 +19,6 @@ "from primaite.config.load import data_manipulation_config_path\n", "\n", "from primaite.session.ray_envs import PrimaiteRayEnv\n", - "from ray.rllib.algorithms import ppo\n", "from ray import air, tune\n", "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", @@ -52,8 +51,8 @@ "\n", "config = (\n", " PPOConfig()\n", - " .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n", - " .rollouts(num_rollout_workers=0)\n", + " .environment(env=PrimaiteRayEnv, env_config=env_config)\n", + " .env_runners(num_env_runners=0)\n", " .training(train_batch_size=128)\n", ")\n" ] @@ -74,7 +73,7 @@ "tune.Tuner(\n", " \"PPO\",\n", " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128}\n", + " stop={\"timesteps_total\": 512}\n", " ),\n", " param_space=config\n", ").fit()\n" diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 9faf5820..1e247e81 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -43,7 +43,10 @@ "outputs": [], "source": [ "with open(data_manipulation_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)" + " cfg = yaml.safe_load(f)\n", + "for agent in cfg['agents']:\n", + " if agent['ref'] == 'defender':\n", + " agent['agent_settings']['flatten_obs']=True" ] }, { @@ -177,7 +180,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Using-Episode-Schedules.ipynb b/src/primaite/notebooks/Using-Episode-Schedules.ipynb index b0669472..062c7135 100644 --- a/src/primaite/notebooks/Using-Episode-Schedules.ipynb +++ b/src/primaite/notebooks/Using-Episode-Schedules.ipynb @@ -298,8 +298,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_A'].action_history[i].action\n", - " red_action = env.game.agents['red_A'].action_history[i].action\n", + " green_action = env.game.agents['green_A'].history[i].action\n", + " red_action = env.game.agents['red_A'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] @@ -329,8 +329,8 @@ "table = PrettyTable()\n", "table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n", "for i in range(21):\n", - " green_action = env.game.agents['green_B'].action_history[i].action\n", - " red_action = env.game.agents['red_B'].action_history[i].action\n", + " green_action = env.game.agents['green_B'].history[i].action\n", + " red_action = env.game.agents['red_B'].history[i].action\n", " table.add_row([i, green_action, red_action])\n", "print(table)" ] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index edb8a476..477efa9b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -11,6 +11,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT +from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) @@ -60,7 +61,7 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = { - "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + "agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()} } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(step, action, state, reward) @@ -89,9 +90,10 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) self.episode_counter += 1 + PacketCapture.clear() self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) state = self.game.get_sim_state() @@ -125,5 +127,5 @@ class PrimaiteGymEnv(gymnasium.Env): def close(self): """Close the simulation.""" if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 8bbc1b07..2901457f 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -87,7 +87,7 @@ class PrimaiteIO: """Return the path where agent actions will be saved.""" return self.session_path / "agent_actions" / f"episode_{episode}.json" - def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: + def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None: """Take the contents of the agent action log and write it to a file. :param episode: Episode number diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 5149a225..f9ab3405 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -11,6 +11,7 @@ from primaite.session.environment import _LOGGER, PrimaiteGymEnv from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT +from primaite.simulator.system.core.packet_capture import PacketCapture class PrimaiteRayMARLEnv(MultiAgentEnv): @@ -45,7 +46,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) - + self._obs_space_in_preferred_format = True + self._action_space_in_preferred_format = True super().__init__() @property @@ -59,10 +61,11 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) self.episode_counter += 1 + PacketCapture.clear() self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) self.game.setup_for_episode(episode=self.episode_counter) state = self.game.get_sim_state() @@ -138,8 +141,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def close(self): """Close the simulation.""" if self.io.settings.save_agent_actions: - all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} - self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()} + self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter) class PrimaiteRayEnv(gymnasium.Env): diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index cf38e94b..bc8a0584 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -21,6 +21,8 @@ class PacketCapture: The PCAPs are logged to: //__pcap.log """ + _logger_instances: List[logging.Logger] = [] + def __init__( self, hostname: str, @@ -65,10 +67,12 @@ class PacketCapture: if outbound: self.outbound_logger = logging.getLogger(self._get_logger_name(outbound)) + PacketCapture._logger_instances.append(self.outbound_logger) logger = self.outbound_logger else: self.inbound_logger = logging.getLogger(self._get_logger_name(outbound)) logger = self.inbound_logger + PacketCapture._logger_instances.append(self.inbound_logger) logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs logger.addHandler(file_handler) @@ -122,3 +126,13 @@ class PacketCapture: if SIM_OUTPUT.save_pcap_logs: msg = frame.model_dump_json() self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL + + @staticmethod + def clear(): + """Close all open PCAP file handlers.""" + for logger in PacketCapture._logger_instances: + handlers = logger.handlers[:] + for handler in handlers: + logger.removeHandler(handler) + handler.close() + PacketCapture._logger_instances = [] diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 7c38057e..dff536de 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,6 +1,6 @@ import yaml -from primaite.game.agent.interface import AgentActionHistoryItem +from primaite.game.agent.interface import AgentHistoryItem from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv @@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), ) @@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent): state = game.get_sim_state() reward_value = comp.calculate( state, - last_action_response=AgentActionHistoryItem( + last_action_response=AgentHistoryItem( timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response ), )