From 922298eaf02f2ab7897f7523aa9bf3122add51f4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 20:07:02 +0000 Subject: [PATCH 01/19] Make database admin action possible --- .../config/_package_data/example_config.yaml | 69 ++++++++++++++++--- src/primaite/game/agent/rewards.py | 41 +++++++++++ .../system/applications/database_client.py | 30 +++++++- 3 files changed, 129 insertions(+), 11 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f85baf10..a0e9667e 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -45,24 +45,39 @@ agents: - node_name: client_2 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - type: DUMMY agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -74,14 +89,36 @@ agents: - node_name: client_1 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + reward_function: reward_components: - type: DUMMY + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + @@ -572,6 +609,14 @@ agents: weight: 0.33 options: node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_2 agent_settings: @@ -717,6 +762,10 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient @@ -740,6 +789,10 @@ simulation: data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index b5d5f998..acc37711 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -242,6 +242,46 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class GreenAdminDatabaseUnreachablePenalty(AbstractReward): + """Penalises the agent when the green db clients fail to connect to the database.""" + + def __init__(self, node_hostname: str) -> None: + """ + Initialise the reward component. + + :param node_hostname: Hostname of the node where the database client sits. + :type node_hostname: str + """ + self._node = node_hostname + self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + + def calculate(self, state: Dict) -> float: + """ + Calculate the reward based on current simulation state. + + :param state: The current state of the simulation. + :type state: Dict + """ + db_state = access_from_nested_dict(state, self.location_in_state) + if db_state is NOT_PRESENT_IN_STATE or "connections_status" not in db_state: + _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + connections_status = db_state["connections_status"] + if False in connections_status: + return -1.0 + return 0 + + @classmethod + def from_config(cls, config: Dict) -> AbstractReward: + """ + Build the reward component object from config. + + :param config: Configuration dictionary. + :type config: Dict + """ + node_hostname = config.get("node_hostname") + return cls(node_hostname=node_hostname) + + class RewardFunction: """Manages the reward function for the agent.""" @@ -250,6 +290,7 @@ class RewardFunction: "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, + "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, } def __init__(self): diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 50d9f3d4..67c0c9b4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -1,8 +1,9 @@ from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from uuid import uuid4 from primaite import getLogger +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application @@ -25,6 +26,8 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} + _connections_status: List[bool] = [] + """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" @@ -33,6 +36,20 @@ class DatabaseClient(Application): super().__init__(**kwargs) self.set_original_state() + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request("execute", RequestType(func=lambda request, context: self.execute())) + return rm + + def execute(self) -> bool: + """Execution definition for db client: perform a select query.""" + if self.connections: + can_connect = self.connect(connection_id=list(self.connections.keys())[-1]) + else: + can_connect = self.connect() + self._connections_status.append(can_connect) + return can_connect + def set_original_state(self): """Sets the original state.""" _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") @@ -52,8 +69,11 @@ class DatabaseClient(Application): :return: A dictionary representing the current state. """ - pass - return super().describe_state() + state = super().describe_state() + # list of connections that were established or verified during this step. + state["connections_status"] = [c for c in self._connections_status] + self._connections_status.clear() + return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): """ @@ -74,6 +94,10 @@ class DatabaseClient(Application): if not connection_id: connection_id = str(uuid4()) + # if we are reusing a connection_id, remove it from self.connections so that its new status can be populated + # warning: janky + self._connections.pop(connection_id, None) + self.connected = self._connect( server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id ) From c54f82fb1bd8c961bc4d7a250e8ea9572fb44b33 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 26 Feb 2024 20:08:13 +0000 Subject: [PATCH 02/19] Start implementing green agent logic for UC2 --- src/primaite/game/agent/scripted_agents.py | 62 +++++++++++++++++++++- src/primaite/game/game.py | 5 +- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index 3748494b..a88e563d 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -1,10 +1,70 @@ """Agents with predefined behaviours.""" +from typing import Dict, Optional, Tuple + +import numpy as np +import pydantic +from gymnasium.core import ObsType + +from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.rewards import RewardFunction -class GreenWebBrowsingAgent(AbstractScriptedAgent): +class GreenUC2Agent(AbstractScriptedAgent): """Scripted agent which attempts to send web requests to a target node.""" + class GreenUC2AgentSettings(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="forbid") + action_probabilities: Dict[int, float] + """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" + random_seed: Optional[int] = None + + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: + if not abs(sum(v.values()) - 1) < 1e-6: + raise ValueError(f"Green action probabilities must sum to 1") + return v + + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]: + if not all((i in v) for i in range(len(v))): + raise ValueError( + "Green action probabilities must be defined as a mapping where the keys are consecutive integers " + "from 0 to N." + ) + + def __init__( + self, + agent_name: str, + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + settings: Dict = {}, + ) -> None: + # If the action probabilities are not specified, create equal probabilities for all actions + if "action_probabilities" not in settings: + num_actions = len(action_space.action_map) + settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} + + # If seed not specified, set it to None so that numpy chooses a random one. + settings.setdefault("random_seed") + + self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings) + + self.rng = np.random.default_rng(self.settings.random_seed) + + # convert probabilities from + self.probabilities = np.array[self.settings.action_probabilities.values()] + + super().__init__(agent_name, action_space, observation_space, reward_function) + + def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]: + choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) + return self.action_manager.get_action(choice) + raise NotImplementedError diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ed98accd..a9d564ba 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents import GreenUC2Agent from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -392,9 +393,9 @@ class PrimaiteGame: agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT - if agent_type == "GreenWebBrowsingAgent": + if agent_type == "GreenUC2Agent": # TODO: implement non-random agents and fix this parsing - new_agent = RandomAgent( + new_agent = GreenUC2Agent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, From af8ca82fcbbb22a7c3d529b7f28f1befb1c20104 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 27 Feb 2024 13:30:16 +0000 Subject: [PATCH 03/19] Get the db admin green agent working --- .../config/_package_data/example_config.yaml | 19 +- .../example_config_2_rl_agents.yaml | 2 +- src/primaite/game/agent/actions.py | 43 ++-- .../game/agent/data_manipulation_bot.py | 10 +- src/primaite/game/agent/interface.py | 10 +- src/primaite/game/agent/scripted_agents.py | 43 ++-- src/primaite/game/game.py | 13 +- src/primaite/notebooks/uc2_demo.ipynb | 232 +++++++++++++++++- .../assets/configs/bad_primaite_session.yaml | 2 +- .../configs/basic_switched_network.yaml | 2 +- .../configs/eval_only_primaite_session.yaml | 2 +- tests/assets/configs/multi_agent_session.yaml | 2 +- .../assets/configs/test_primaite_session.yaml | 2 +- .../configs/train_only_primaite_session.yaml | 2 +- tests/conftest.py | 5 +- .../_game/_agent/test_probabilistic_agent.py | 84 +++++++ 16 files changed, 386 insertions(+), 87 deletions(-) create mode 100644 tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index a0e9667e..6813161d 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,12 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenUC2Agent + type: probabilistic_agent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -69,15 +74,14 @@ agents: reward_components: - type: DUMMY + - ref: client_1_green_user + team: GREEN + type: probabilistic_agent agent_settings: action_probabilities: 0: 0.3 1: 0.6 2: 0.1 - - - ref: client_1_green_user - team: GREEN - type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -113,11 +117,6 @@ agents: reward_components: - type: DUMMY - agent_settings: - action_probabilities: - 0: 0.3 - 1: 0.6 - 2: 0.1 diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 93019c9d..df6130d1 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -27,7 +27,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 1793d420..18cb6262 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -607,7 +607,6 @@ class ActionManager: def __init__( self, - game: "PrimaiteGame", # reference to game for information lookup actions: List[Dict], # stores list of actions available to agent nodes: List[Dict], # extra configuration for each node max_folders_per_node: int = 2, # allows calculating shape @@ -618,7 +617,7 @@ class ActionManager: max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. + ip_address_list: List[str] = [], # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: """Init method for ActionManager. @@ -649,7 +648,6 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.game: "PrimaiteGame" = game self.node_names: List[str] = [n["node_name"] for n in nodes] """List of node names in this action space. The list order is the mapping between node index and node name.""" self.application_names: List[List[str]] = [] @@ -707,25 +705,7 @@ class ActionManager: self.protocols: List[str] = protocols self.ports: List[str] = ports - self.ip_address_list: List[str] - - # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from - # the nodes in the simulation. - # TODO: refactor. Options: - # 1: This should be pulled out into it's own function for clarity - # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to - # go through the nodes here. - if ip_address_list is not None: - self.ip_address_list = ip_address_list - else: - self.ip_address_list = [] - for node_name in self.node_names: - node_obj = self.game.simulation.network.get_node_by_hostname(node_name) - if node_obj is None: - continue - network_interfaces = node_obj.network_interfaces - for nic_uuid, nic_obj in network_interfaces.items(): - self.ip_address_list.append(nic_obj.ip_address) + self.ip_address_list: List[str] = ip_address_list # action_args are settings which are applied to the action space as a whole. global_action_args = { @@ -958,6 +938,12 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from + # the nodes in the simulation. + # TODO: refactor. Options: + # 1: This should be pulled out into it's own function for clarity + # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to + # go through the nodes here. ip_address_order = cfg["options"].pop("ip_address_order", {}) ip_address_list = [] for entry in ip_address_order: @@ -967,13 +953,22 @@ class ActionManager: ip_address = node_obj.network_interface[nic_num].ip_address ip_address_list.append(ip_address) + if not ip_address_list: + node_names = [n["node_name"] for n in cfg.get("nodes", {})] + for node_name in node_names: + node_obj = game.simulation.network.get_node_by_hostname(node_name) + if node_obj is None: + continue + network_interfaces = node_obj.network_interfaces + for nic_uuid, nic_obj in network_interfaces.items(): + ip_address_list.append(nic_obj.ip_address) + obj = cls( - game=game, actions=cfg["action_list"], **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=ip_address_list or None, + ip_address_list=ip_address_list, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 126c55ec..b5de9a5a 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from gymnasium.core import ObsType @@ -26,7 +26,7 @@ class DataManipulationAgent(AbstractScriptedAgent): ) self.next_execution_timestep = timestep + random_timestep_increment - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -36,12 +36,10 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - current_timestep = self.action_manager.game.step_counter - - if current_timestep < self.next_execution_timestep: + if timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} - self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 276715f7..4f434bad 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -112,7 +112,7 @@ class AbstractAgent(ABC): return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -122,6 +122,8 @@ class AbstractAgent(ABC): :type obs: ObsType :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? :type reward: float, optional + :param timestep: The current timestep in the simulation, used for non-RL agents. Optional + :type timestep: int :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ @@ -144,13 +146,13 @@ class AbstractAgent(ABC): class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" - ... + pass class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. :param obs: _description_ @@ -183,7 +185,7 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index a88e563d..28d94062 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -11,30 +11,39 @@ from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -class GreenUC2Agent(AbstractScriptedAgent): - """Scripted agent which attempts to send web requests to a target node.""" +class ProbabilisticAgent(AbstractScriptedAgent): + """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" + + class Settings(pydantic.BaseModel): + """Config schema for Probabilistic agent settings.""" - class GreenUC2AgentSettings(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="forbid") + """Strict validation.""" action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" random_seed: Optional[int] = None + """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" + # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way + # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @pydantic.field_validator("action_probabilities", mode="after") @classmethod def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: + """Make sure the probabilities sum to 1.""" if not abs(sum(v.values()) - 1) < 1e-6: - raise ValueError(f"Green action probabilities must sum to 1") + raise ValueError("Green action probabilities must sum to 1") return v @pydantic.field_validator("action_probabilities", mode="after") @classmethod def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]: + """Ensure that the keys of the probability dictionary cover all integers from 0 to N.""" if not all((i in v) for i in range(len(v))): raise ValueError( "Green action probabilities must be defined as a mapping where the keys are consecutive integers " "from 0 to N." ) + return v def __init__( self, @@ -52,23 +61,27 @@ class GreenUC2Agent(AbstractScriptedAgent): # If seed not specified, set it to None so that numpy chooses a random one. settings.setdefault("random_seed") - self.settings = GreenUC2Agent.GreenUC2AgentSettings(settings) + self.settings = ProbabilisticAgent.Settings(**settings) self.rng = np.random.default_rng(self.settings.random_seed) # convert probabilities from - self.probabilities = np.array[self.settings.action_probabilities.values()] + self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) - def get_action(self, obs: ObsType, reward: float = 0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + """ + Choose a random action from the action space. + + The probability of each action is given by the corresponding index in ``self.probabilities``. + + :param obs: Current observation of the simulation + :type obs: ObsType + :param reward: Reward for the last step, not used for scripted agents, defaults to 0 + :type reward: float, optional + :return: Action to be taken in CAOS format. + :rtype: Tuple[str, Dict] + """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) return self.action_manager.get_action(choice) - - raise NotImplementedError - - -class RedDatabaseCorruptingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to corrupt the database of the target node.""" - - raise NotImplementedError diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index a9d564ba..b44abe16 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,10 +7,10 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction -from primaite.game.agent.scripted_agents import GreenUC2Agent +from primaite.game.agent.scripted_agents import ProbabilisticAgent from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -165,7 +165,7 @@ class PrimaiteGame: for agent in self.agents: obs = agent.observation_manager.current_observation rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew) + action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter) agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) @@ -393,14 +393,15 @@ class PrimaiteGame: agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) # CREATE AGENT - if agent_type == "GreenUC2Agent": + if agent_type == "probabilistic_agent": # TODO: implement non-random agents and fix this parsing - new_agent = GreenUC2Agent( + settings = agent_cfg.get("agent_settings") + new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=reward_function, - agent_settings=agent_settings, + settings=settings, ) game.agents.append(new_agent) elif agent_type == "ProxyAgent": diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c4fe4c9a..fa4a28a4 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -334,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [] }, @@ -346,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "tags": [] }, @@ -371,11 +371,150 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-02-27 09:43:39,312::WARNING::primaite.game.game::275::service type not found DatabaseClient\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 0}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1},\n", + " 2: {'nic_status': 0}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ "# create the env\n", "with open(example_config_path(), 'r') as f:\n", @@ -403,7 +542,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -421,15 +560,57 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 211, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 212, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 213, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 214, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 215, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 216, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 217, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 218, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 219, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 220, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 221, Red action: ATTACK from client 2, Blue reward:-0.32\n", + "step: 222, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 223, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 224, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 225, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 226, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 227, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 228, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 229, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 230, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 231, Red action: DO NOTHING, Blue reward:-0.42\n", + "step: 232, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 233, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 234, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 235, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 236, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 237, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 238, Red action: ATTACK from client 2, Blue reward:-0.32\n", + "step: 239, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 240, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 241, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 242, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 243, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 244, Red action: DO NOTHING, Blue reward:-0.32\n", + "step: 245, Red action: DO NOTHING, Blue reward:-0.32\n" + ] + } + ], "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", - " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" ] }, { @@ -509,7 +690,7 @@ "\n", "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -523,8 +704,8 @@ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", "print(f\"Blue reward:{reward}\" )" ] }, @@ -582,6 +763,33 @@ "obs['ACL']" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = env.game.simulation.network" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dbc = net.get_node_by_hostname('client_1').software_manager.software.get('DatabaseClient')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dbc._query_success_tracker" + ] + }, { "cell_type": "code", "execution_count": null, @@ -606,7 +814,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 5bdc3273..892e6af7 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -21,7 +21,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index d1cec079..ad2ea787 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -33,7 +33,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 8361e318..9b668686 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 87bd9d1c..5a7d8366 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -31,7 +31,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 76190a64..42dd27fb 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -29,7 +29,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 5d004c7e..8a4a1178 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/conftest.py b/tests/conftest.py index 5084c339..2add835f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import pytest import yaml @@ -309,7 +309,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action @@ -478,7 +478,6 @@ def game_and_agent(): ] action_space = ActionManager( - game=game, actions=actions, # ALL POSSIBLE ACTIONS nodes=[ { diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py new file mode 100644 index 00000000..f0b37cac --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -0,0 +1,84 @@ +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents import ProbabilisticAgent + + +def test_probabilistic_agent(): + """ + Check that the probabilistic agent selects actions with approximately the right probabilities. + + Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator), + we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values + were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number + generation. + """ + N_TRIALS = 10_000 + P_DO_NOTHING = 0.1 + P_NODE_APPLICATION_EXECUTE = 0.3 + P_NODE_FILE_DELETE = 0.6 + MIN_DO_NOTHING = 850 + MAX_DO_NOTHING = 1150 + MIN_NODE_APPLICATION_EXECUTE = 2800 + MAX_NODE_APPLICATION_EXECUTE = 3200 + MIN_NODE_FILE_DELETE = 5750 + MAX_NODE_FILE_DELETE = 6250 + + action_space = ActionManager( + actions=[ + {"type": "DONOTHING"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_DELETE"}, + ], + nodes=[ + { + "node_name": "client_1", + "applications": [{"application_name": "WebBrowser"}], + "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + }, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + act_map={ + 0: {"action": "DONOTHING", "options": {}}, + 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, + 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, + }, + ) + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() + + pa = ProbabilisticAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + settings={ + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + "random_seed": 120, + }, + ) + + do_nothing_count = 0 + node_application_execute_count = 0 + node_file_delete_count = 0 + for _ in range(N_TRIALS): + a = pa.get_action(0, timestep=0) + if a == ("DONOTHING", {}): + do_nothing_count += 1 + elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + node_application_execute_count += 1 + elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + node_file_delete_count += 1 + else: + raise AssertionError("Probabilistic agent produced an unexpected action.") + + assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING + assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE + assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE From 2f3e40fb6b6abe943770119b109319bc0edb7266 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Feb 2024 13:22:05 +0000 Subject: [PATCH 04/19] Fix issue around reset --- src/primaite/game/game.py | 2 +- src/primaite/session/environment.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index eeb0d007..3b9a21d4 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -417,7 +417,7 @@ class PrimaiteGame: agent_settings=agent_settings, ) else: - msg(f"Configuration error: {agent_type} is not a valid agent type.") + msg = f"Configuration error: {agent_type} is not a valid agent type." _LOGGER.error(msg) raise ValueError(msg) game.agents[agent_cfg["ref"]] = new_agent diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index f8dbab9d..d54503a3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,3 +1,4 @@ +import copy import json from typing import Any, Dict, Optional, SupportsFloat, Tuple @@ -23,7 +24,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.game_config: Dict = game_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" @@ -78,7 +79,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() From bd0b2e003346482fb84a3ecac86c8100439382d2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Feb 2024 13:22:41 +0000 Subject: [PATCH 05/19] Remove redundant notebook cells --- src/primaite/notebooks/uc2_demo.ipynb | 33 +++------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index cf973905..13fb7d80 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -345,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "tags": [] }, @@ -357,7 +357,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "tags": [] }, @@ -412,7 +412,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -595,33 +595,6 @@ "env.reset()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "net = env.game.simulation.network" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dbc = net.get_node_by_hostname('client_1').software_manager.software.get('DatabaseClient')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dbc._query_success_tracker" - ] - }, { "cell_type": "code", "execution_count": null, From 10a40538876930afa371bac5c77691626917472b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 1 Mar 2024 15:14:00 +0000 Subject: [PATCH 06/19] Fix tests --- docs/source/configuration/agents.rst | 6 +++--- .../_package_data/example_config_2_rl_agents.yaml | 2 +- src/primaite/session/session.py | 9 +++------ tests/assets/configs/dmz_network.yaml | 2 +- tests/e2e_integration_tests/test_primaite_session.py | 10 +++++----- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index f32843b1..ac67c365 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo ... - ref: green_agent_example team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: @@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive ``type`` -------- -Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour. +Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour. Available agent types: -- ``GreenWebBrowsingAgent`` +- ``probabilistic_agent`` - ``ProxyAgent`` - ``RedDatabaseCorruptingAgent`` diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index d6d3f044..b6b07afa 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -64,7 +64,7 @@ agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index b8f80e95..d244f6b0 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings @@ -40,7 +39,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self, game: PrimaiteGame): + def __init__(self, game_cfg: Dict): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -57,7 +56,7 @@ class PrimaiteSession: self.io_manager: Optional["SessionIO"] = None """IO manager for the session.""" - self.game: PrimaiteGame = game + self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" def start_session(self) -> None: @@ -93,9 +92,7 @@ class PrimaiteSession: io_settings = cfg.get("io_settings", {}) io_manager = SessionIO(SessionIOSettings(**io_settings)) - game = PrimaiteGame.from_config(cfg) - - sess = cls(game=game) + sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 880735d9..56a68410 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -65,7 +65,7 @@ game: agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 7785e4ae..da13dcd8 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -21,15 +21,15 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.game.simulation - assert len(session.game.agents) == 3 - assert len(session.game.rl_agents) == 1 + assert session.env.game.simulation + assert len(session.env.game.agents) == 3 + assert len(session.env.game.rl_agents) == 1 assert session.policy assert session.env - assert session.game.simulation.network - assert len(session.game.simulation.network.nodes) == 10 + assert session.env.game.simulation.network + assert len(session.env.game.simulation.network.nodes) == 10 @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session): From ed01293b862cb28fc7d13b9d04e994d98ca663cb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 1 Mar 2024 16:02:27 +0000 Subject: [PATCH 07/19] Make db admin reward persistent --- src/primaite/game/agent/rewards.py | 8 +++++--- src/primaite/game/game.py | 2 +- .../simulator/system/applications/database_client.py | 9 ++++----- .../simulator/system/applications/web_browser.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 4eb1ab3f..882ad024 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -263,11 +263,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :type state: Dict """ db_state = access_from_nested_dict(state, self.location_in_state) - if db_state is NOT_PRESENT_IN_STATE or "connections_status" not in db_state: + if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") - connections_status = db_state["connections_status"] - if False in connections_status: + last_connection_successful = db_state["last_connection_successful"] + if last_connection_successful is False: return -1.0 + elif last_connection_successful is True: + return 1.0 return 0 @classmethod diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b9f92d3a..cf21dd40 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -296,7 +296,7 @@ class PrimaiteGame: if service_type == "DatabaseService": if "options" in service_cfg: opt = service_cfg["options"] - new_service.password = opt.get("backup_server_ip", None) + new_service.password = opt.get("db_password", None) new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) if service_type == "FTPServer": if "options" in service_cfg: diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index fe8180d7..addad35a 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -1,5 +1,5 @@ from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from uuid import uuid4 from primaite import getLogger @@ -26,7 +26,7 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} - _connections_status: List[bool] = [] + _last_connection_successful: Optional[bool] = None """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): @@ -46,7 +46,7 @@ class DatabaseClient(Application): can_connect = self.connect(connection_id=list(self.connections.keys())[-1]) else: can_connect = self.connect() - self._connections_status.append(can_connect) + self._last_connection_successful = can_connect return can_connect def describe_state(self) -> Dict: @@ -57,8 +57,7 @@ class DatabaseClient(Application): """ state = super().describe_state() # list of connections that were established or verified during this step. - state["connections_status"] = [c for c in self._connections_status] - self._connections_status.clear() + state["last_connection_successful"] = self._last_connection_successful return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6f2c479c..9fa86328 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -199,7 +199,7 @@ class WebBrowser(Application): def state(self) -> Dict: """Return the contents of this dataclass as a dict for use with describe_state method.""" if self.status == self._HistoryItemStatus.LOADED: - outcome = self.response_code + outcome = self.response_code.value else: outcome = self.status.value return {"url": self.url, "outcome": outcome} From 2a1d99cccee0c49360f84fc0640240c77052c65f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 1 Mar 2024 16:36:41 +0000 Subject: [PATCH 08/19] Fix problem with checking connection for db admin --- .../system/applications/database_client.py | 14 ++++++++------ .../applications/red_applications/dos_bot.py | 2 +- .../system/services/database/database_service.py | 12 ++++++++++++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index addad35a..69065225 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -43,9 +43,9 @@ class DatabaseClient(Application): def execute(self) -> bool: """Execution definition for db client: perform a select query.""" if self.connections: - can_connect = self.connect(connection_id=list(self.connections.keys())[-1]) + can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) else: - can_connect = self.connect() + can_connect = self.check_connection(connection_id=str(uuid4())) self._last_connection_successful = can_connect return can_connect @@ -79,15 +79,17 @@ class DatabaseClient(Application): if not connection_id: connection_id = str(uuid4()) - # if we are reusing a connection_id, remove it from self.connections so that its new status can be populated - # warning: janky - self._connections.pop(connection_id, None) - self.connected = self._connect( server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id ) return self.connected + def check_connection(self, connection_id:str) -> bool: + if not self._can_perform_action(): + return False + print(self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)) + return self.connected + def _connect( self, server_ip_address: IPv4Address, diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 9dac6b25..1247bc99 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -28,7 +28,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient, Application): +class DoSBot(DatabaseClient): """A bot that simulates a Denial of Service attack.""" target_ip_address: Optional[IPv4Address] = None diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 9fdfd5ff..c73132eb 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -221,6 +221,18 @@ class DatabaseService(Service): } else: return {"status_code": 404, "data": False} + elif query == "SELECT * FROM pg_stat_activity": + # Check if the connection is active. + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 401, "data": False} else: # Invalid query self.sys_log.info(f"{self.name}: Invalid {query}") From 80158fd9b4e1b2beeff3c42843c1f50b0dbc6716 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 11:18:06 +0000 Subject: [PATCH 09/19] Make db manipulation bot work with db client --- src/primaite/game/game.py | 6 +-- .../network/transmission/network_layer.py | 2 + .../system/applications/database_client.py | 9 +++- .../red_applications/data_manipulation_bot.py | 48 +++++++++++++++---- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cf21dd40..10c02b39 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -409,9 +409,6 @@ class PrimaiteGame: # CREATE REWARD FUNCTION reward_function = RewardFunction.from_config(reward_function_cfg) - # OTHER AGENT SETTINGS - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - # CREATE AGENT if agent_type == "probabilistic_agent": # TODO: implement non-random agents and fix this parsing @@ -424,6 +421,7 @@ class PrimaiteGame: settings=settings, ) elif agent_type == "ProxyAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -433,6 +431,8 @@ class PrimaiteGame: ) game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index dc848ade..22d7f97d 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -15,6 +15,8 @@ class IPProtocol(Enum): .. _List of IPProtocols: """ + NONE = "none" + """Placeholder for a non-port.""" TCP = "tcp" """Transmission Control Protocol.""" UDP = "udp" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 69065225..a8eac196 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -84,7 +84,14 @@ class DatabaseClient(Application): ) return self.connected - def check_connection(self, connection_id:str) -> bool: + def check_connection(self, connection_id: str) -> bool: + """Check whether the connection can be successfully re-established. + + :param connection_id: connection ID to check + :type connection_id: str + :return: Whether the connection was successfully re-established. + :rtype: bool + """ if not self._can_perform_action(): return False print(self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)) diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 5fe951b7..11eb71f5 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -1,10 +1,13 @@ from enum import IntEnum from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -32,12 +35,12 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(DatabaseClient): +class DataManipulationBot(Application): """A bot that simulates a script which performs a SQL injection attack.""" - server_ip_address: Optional[IPv4Address] = None + # server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None - server_password: Optional[str] = None + # server_password: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -46,8 +49,31 @@ class DataManipulationBot(DatabaseClient): "Whether to repeat attacking once finished." def __init__(self, **kwargs): + kwargs["name"] = "DataManipulationBot" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) - self.name = "DataManipulationBot" + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + state = super().describe_state() + return state + + @property + def _host_db_client(self) -> DatabaseClient: + """Return the database client that is installed on the same machine as the DataManipulationBot.""" + db_client = self.software_manager.software.get("DatabaseClient") + if db_client is None: + _LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.") + return db_client def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -76,8 +102,8 @@ class DataManipulationBot(DatabaseClient): :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address - self.payload = payload self.server_password = server_password + self.payload = payload self.port_scan_p_of_success = port_scan_p_of_success self.data_manipulation_p_of_success = data_manipulation_p_of_success self.repeat = repeat @@ -123,15 +149,17 @@ class DataManipulationBot(DatabaseClient): :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. """ + self._host_db_client.server_ip_address = self.server_ip_address + self._host_db_client.server_password = self.server_password if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not len(self.connections): - self.connect() - if len(self.connections): - self.query(self.payload) + if not len(self._host_db_client.connections): + self._host_db_client.connect() + if len(self._host_db_client.connections): + self._host_db_client.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True if attack_successful: From 4a292a6239d748baf224ef8a52519a1eed6a0f3c Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 11:23:24 +0000 Subject: [PATCH 10/19] Fix checking connection in db client --- src/primaite/simulator/system/applications/database_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index a8eac196..7b259ff4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -94,8 +94,7 @@ class DatabaseClient(Application): """ if not self._can_perform_action(): return False - print(self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)) - return self.connected + return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id) def _connect( self, From 9762927289568d7b68da5214a436e68ebd4b61c4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 11:43:24 +0000 Subject: [PATCH 11/19] Update notebook with new changes --- src/primaite/notebooks/uc2_demo.ipynb | 620 ++++++++++++++++++++++++-- 1 file changed, 581 insertions(+), 39 deletions(-) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 13fb7d80..36942b73 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -13,7 +13,7 @@ "source": [ "## Scenario\n", "\n", - "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database. Occasionally, admins need to access the database directly from the clients.\n", "\n", "[](_package_data/uc2_network.png)\n", "\n", @@ -46,7 +46,9 @@ "source": [ "## Green agent\n", "\n", - "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available.\n", + "\n", + "Sometimes, the green agents send a request directly to the database to check that it is reachable." ] }, { @@ -68,7 +70,9 @@ "source": [ "## Blue agent\n", "\n", - "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router." + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router.\n", + "\n", + "However, these rules will also impact greens' ability to check the database connection. The blue agent should only block the infected client, it should let the other client connect freely. Once the attack has begun, automated traffic monitoring will detect it as suspicious network traffic. The blue agent's observation space will show this as an increase in the number of malicious network events (NMNE) on one of the network interfaces. To achieve optimal reward, the agent should only block the client which has the non-zero outbound NMNE." ] }, { @@ -101,9 +105,11 @@ "The red agent does not use information about the state of the network to decide its action.\n", "\n", "### Green\n", - "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, it will do nothing 30% of the time, send a web request 60% of the time, and send a db status check 10% of the time.\n", "\n", - "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender.\n", + "\n", + "Also, when the green agent is blocked from checking the database status, it causes a small negative reward." ] }, { @@ -322,9 +328,10 @@ "source": [ "## Reward Function\n", "\n", - "The blue agent's reward is calculated using two measures:\n", + "The blue agent's reward is calculated using these measures:\n", "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", "2. Whether each green agents' most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "3. Whether each green agents' most recent DB status check was successful (+1 for a successful connection, -1 for no connection).\n", "\n", "The file status reward and the two green-agent-related rewards are averaged to get a total step reward.\n" ] @@ -345,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "tags": [] }, @@ -357,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "tags": [] }, @@ -382,9 +389,169 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resetting environment, episode 0, avg. reward: 0.0\n", + "env created successfully\n", + "{'ACL': {1: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 0,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 2: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 1,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 3: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 2,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 4: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 3,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 5: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 4,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 6: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 5,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 7: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 6,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 8: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 7,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 9: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 8,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0},\n", + " 10: {'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'permission': 0,\n", + " 'position': 9,\n", + " 'protocol': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0}},\n", + " 'ICS': 0,\n", + " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", + " 2: {'PROTOCOLS': {'ALL': 1}},\n", + " 3: {'PROTOCOLS': {'ALL': 1}},\n", + " 4: {'PROTOCOLS': {'ALL': 1}},\n", + " 5: {'PROTOCOLS': {'ALL': 1}},\n", + " 6: {'PROTOCOLS': {'ALL': 1}},\n", + " 7: {'PROTOCOLS': {'ALL': 1}},\n", + " 8: {'PROTOCOLS': {'ALL': 1}},\n", + " 9: {'PROTOCOLS': {'ALL': 1}},\n", + " 10: {'PROTOCOLS': {'ALL': 0}}},\n", + " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", + " 'health_status': 1}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", + " 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0,\n", + " 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}}\n" + ] + } + ], "source": [ "# create the env\n", "with open(example_config_path(), 'r') as f:\n", @@ -407,12 +574,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -1.0 when green agents try to access the webpage." + "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -0.8 when green agents try to access the webpage." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -430,9 +597,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1, Red action: DO NOTHING, Blue reward:0.77\n", + "step: 2, Red action: DO NOTHING, Blue reward:0.77\n", + "step: 3, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 4, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 5, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 6, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 7, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 8, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 9, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 10, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 11, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 12, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 13, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 14, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 15, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 16, Red action: DO NOTHING, Blue reward:1.10\n", + "step: 17, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 18, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 19, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 20, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 21, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 22, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 23, Red action: DO NOTHING, Blue reward:1.20\n", + "step: 24, Red action: ATTACK from client 2, Blue reward:0.52\n", + "step: 25, Red action: DO NOTHING, Blue reward:0.52\n", + "step: 26, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 27, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 28, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 29, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 30, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 31, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 32, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 33, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 34, Red action: DO NOTHING, Blue reward:-0.80\n", + "step: 35, Red action: DO NOTHING, Blue reward:-0.80\n" + ] + } + ], "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", @@ -448,9 +657,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}}, 'health_status': 1}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 1, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 1}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] + } + ], "source": [ "pprint(obs['NODES'])" ] @@ -464,9 +729,65 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 3, 'operating_status': 1}},\n", + " 'operating_status': 1},\n", + " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 2}}, 'health_status': 1}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 1, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1},\n", + " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", + " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", + " 'nmne': {'inbound': 0, 'outbound': 1}},\n", + " 2: {'nic_status': 0,\n", + " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", + " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", + " 'operating_status': 1}}\n" + ] + } + ], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", "obs, reward, terminated, truncated, info = env.step(1) # scan webapp service\n", @@ -481,6 +802,13 @@ "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -490,9 +818,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 38\n", + "Red action: DONOTHING\n", + "Green action: DONOTHING\n", + "Green action: DONOTHING\n", + "Blue reward:-0.8\n" + ] + } + ], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", @@ -515,16 +855,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 52\n", + "Red action: DONOTHING\n", + "Green action: ('NODE_APPLICATION_EXECUTE', {'node_id': 0, 'application_id': 0})\n", + "Green action: ('NODE_APPLICATION_EXECUTE', {'node_id': 0, 'application_id': 0})\n", + "Blue reward:-0.80\n" + ] + } + ], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n", "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", - "print(f\"Blue reward:{reward}\" )" + "print(f\"Blue reward:{reward:.2f}\" )" ] }, { @@ -538,29 +890,69 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 53, Red action: DONOTHING, Blue reward:-0.80\n", + "step: 54, Red action: DONOTHING, Blue reward:-0.80\n", + "step: 55, Red action: DONOTHING, Blue reward:-0.80\n", + "step: 56, Red action: DONOTHING, Blue reward:0.54\n", + "step: 57, Red action: DONOTHING, Blue reward:1.20\n", + "step: 58, Red action: DONOTHING, Blue reward:1.20\n", + "step: 59, Red action: DONOTHING, Blue reward:1.20\n", + "step: 60, Red action: DONOTHING, Blue reward:1.20\n", + "step: 61, Red action: DONOTHING, Blue reward:1.20\n", + "step: 62, Red action: DONOTHING, Blue reward:1.20\n", + "step: 63, Red action: DONOTHING, Blue reward:1.20\n", + "step: 64, Red action: DONOTHING, Blue reward:1.20\n", + "step: 65, Red action: DONOTHING, Blue reward:1.00\n", + "step: 66, Red action: DONOTHING, Blue reward:1.00\n", + "step: 67, Red action: DONOTHING, Blue reward:1.00\n", + "step: 68, Red action: DONOTHING, Blue reward:1.00\n", + "step: 69, Red action: DONOTHING, Blue reward:1.00\n", + "step: 70, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", + "step: 71, Red action: DONOTHING, Blue reward:1.00\n", + "step: 72, Red action: DONOTHING, Blue reward:1.00\n", + "step: 73, Red action: DONOTHING, Blue reward:1.00\n", + "step: 74, Red action: DONOTHING, Blue reward:1.00\n", + "step: 75, Red action: DONOTHING, Blue reward:0.80\n", + "step: 76, Red action: DONOTHING, Blue reward:0.80\n", + "step: 77, Red action: DONOTHING, Blue reward:0.80\n", + "step: 78, Red action: DONOTHING, Blue reward:0.80\n", + "step: 79, Red action: DONOTHING, Blue reward:0.80\n", + "step: 80, Red action: DONOTHING, Blue reward:0.80\n", + "step: 81, Red action: DONOTHING, Blue reward:0.80\n", + "step: 82, Red action: DONOTHING, Blue reward:0.80\n", + "step: 83, Red action: DONOTHING, Blue reward:0.80\n", + "step: 84, Red action: DONOTHING, Blue reward:0.80\n", + "step: 85, Red action: NODE_APPLICATION_EXECUTE, Blue reward:0.80\n" + ] + } + ], "source": [ "env.step(13) # Patch the database\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(26) # Block client 1\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(27) # Block client 2\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "for step in range(30):\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, even though the red agent executes an attack, the reward stays at 1.0" + "Now, even though the red agent executes an attack, the reward stays at 0.8." ] }, { @@ -572,11 +964,168 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{1: {'position': 0,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 2: {'position': 1,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 3: {'position': 2,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 4: {'position': 3,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 5: {'position': 4,\n", + " 'permission': 2,\n", + " 'source_node_id': 7,\n", + " 'source_port': 1,\n", + " 'dest_node_id': 4,\n", + " 'dest_port': 1,\n", + " 'protocol': 3},\n", + " 6: {'position': 5,\n", + " 'permission': 2,\n", + " 'source_node_id': 8,\n", + " 'source_port': 1,\n", + " 'dest_node_id': 4,\n", + " 'dest_port': 1,\n", + " 'protocol': 3},\n", + " 7: {'position': 6,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 8: {'position': 7,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 9: {'position': 8,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0},\n", + " 10: {'position': 9,\n", + " 'permission': 0,\n", + " 'source_node_id': 0,\n", + " 'source_port': 0,\n", + " 'dest_node_id': 0,\n", + " 'dest_port': 0,\n", + " 'protocol': 0}}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obs['ACL']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can slightly increase the reward by unblocking the client which isn't being used by the attacker. If node 6 has outbound NMNEs, let's unblock client 2, and if node 7 has outbound NMNEs, let's unblock client 1." + ] + }, + { + "cell_type": "code", + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ - "obs['ACL']" + "if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", + " # client 1 has NMNEs, let's unblock client 2\n", + " env.step(34) # remove ACL rule 6\n", + "elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", + " env.step(33) # remove ACL rule 5\n", + "else:\n", + " print(\"something went wrong, neither client has NMNEs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 117, Red action: DONOTHING, Blue reward:1.00\n", + "step: 118, Red action: DONOTHING, Blue reward:1.00\n", + "step: 119, Red action: DONOTHING, Blue reward:1.00\n", + "step: 120, Red action: DONOTHING, Blue reward:1.00\n", + "step: 121, Red action: DONOTHING, Blue reward:1.00\n", + "step: 122, Red action: DONOTHING, Blue reward:1.00\n", + "step: 123, Red action: DONOTHING, Blue reward:1.00\n", + "step: 124, Red action: DONOTHING, Blue reward:1.00\n", + "step: 125, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", + "step: 126, Red action: DONOTHING, Blue reward:1.00\n", + "step: 127, Red action: DONOTHING, Blue reward:1.00\n", + "step: 128, Red action: DONOTHING, Blue reward:1.00\n", + "step: 129, Red action: DONOTHING, Blue reward:1.00\n", + "step: 130, Red action: DONOTHING, Blue reward:1.00\n", + "step: 131, Red action: DONOTHING, Blue reward:1.00\n", + "step: 132, Red action: DONOTHING, Blue reward:1.00\n", + "step: 133, Red action: DONOTHING, Blue reward:1.00\n", + "step: 134, Red action: DONOTHING, Blue reward:1.00\n", + "step: 135, Red action: DONOTHING, Blue reward:1.00\n", + "step: 136, Red action: DONOTHING, Blue reward:1.00\n", + "step: 137, Red action: DONOTHING, Blue reward:1.00\n", + "step: 138, Red action: DONOTHING, Blue reward:1.00\n", + "step: 139, Red action: DONOTHING, Blue reward:1.00\n", + "step: 140, Red action: DONOTHING, Blue reward:1.00\n", + "step: 141, Red action: DONOTHING, Blue reward:1.00\n", + "step: 142, Red action: DONOTHING, Blue reward:1.00\n", + "step: 143, Red action: DONOTHING, Blue reward:1.00\n", + "step: 144, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", + "step: 145, Red action: DONOTHING, Blue reward:1.00\n", + "step: 146, Red action: DONOTHING, Blue reward:1.00\n" + ] + } + ], + "source": [ + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )" ] }, { @@ -594,13 +1143,6 @@ "source": [ "env.reset()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 070655cfce3da94db274377d404beb5d86955a8b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 11:47:50 +0000 Subject: [PATCH 12/19] Update data manipulation bot documentation --- .../system/applications/data_manipulation_bot.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 304621dd..9188733b 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -45,7 +45,7 @@ In a simulation, the bot can be controlled by using ``DataManipulationAgent`` wh Implementation ============== -The bot extends :ref:`DatabaseClient` and leverages its connectivity. +The bot connects to a :ref:`DatabaseClient` and leverages its connectivity. The host running ``DataManipulationBot`` must also have a :ref:`DatabaseClient` installed on it. - Uses the Application base class for lifecycle management. - Credentials, target IP and other options set via ``configure``. @@ -65,6 +65,7 @@ Python from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot + from primaite.simulator.system.applications.database_client import DatabaseClient client_1 = Computer( hostname="client_1", @@ -74,6 +75,7 @@ Python operating_state=NodeOperatingState.ON # initialise the computer in an ON state ) network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") @@ -148,6 +150,10 @@ If not using the data manipulation bot manually, it needs to be used with a data data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 Configuration ============= From 4d51b1a4146bb861f80101f4b013df093395a1b2 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 14:57:28 +0000 Subject: [PATCH 13/19] Update configs to new db manipulation bot approach --- src/primaite/simulator/network/networks.py | 63 ++----------------- .../red_applications/data_manipulation_bot.py | 4 ++ tests/assets/configs/basic_firewall.yaml | 2 +- .../test_data_manipulation_bot.py | 15 ++++- 4 files changed, 23 insertions(+), 61 deletions(-) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index f82dee4a..fa9d86ef 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -146,6 +146,9 @@ def arcd_uc2_network() -> Network: ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + db_client_1 = client_1.software_manager.install(DatabaseClient) + db_client_1 = client_1.software_manager.software.get("DatabaseClient") + db_client_1.run() client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") db_manipulation_bot.configure( @@ -165,6 +168,9 @@ def arcd_uc2_network() -> Network: start_up_duration=0, ) client_2.power_on() + client_2.software_manager.install(DatabaseClient) + db_client_2 = client_2.software_manager.software.get("DatabaseClient") + db_client_2.run() web_browser = client_2.software_manager.software.get("WebBrowser") web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) @@ -194,67 +200,10 @@ def arcd_uc2_network() -> Network: database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) - ddl = """ - CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name VARCHAR(50) NOT NULL, - email VARCHAR(50) NOT NULL, - age INT, - city VARCHAR(50), - occupation VARCHAR(50) - );""" - - user_insert_statements = [ - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", - # noqa - ] database_server.software_manager.install(DatabaseService) database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) - database_service._process_sql(ddl, None, None) # noqa - for insert_statement in user_insert_statements: - database_service._process_sql(insert_statement, None, None) # noqa # Web Server web_server = Server( diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 11eb71f5..961f82c2 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -149,6 +149,10 @@ class DataManipulationBot(Application): :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. """ + if self._host_db_client is None: + self.attack_stage = DataManipulationAttackStage.FAILED + return + self._host_db_client.server_ip_address = self.server_ip_address self._host_db_client.server_password = self.server_password if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 71dc31a7..0a892650 100644 --- a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -40,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 2ca67119..6d00886a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -26,8 +26,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.POSTGRES_SERVER - assert data_manipulation_bot.protocol == IPProtocol.TCP + assert data_manipulation_bot.port == Port.NONE + assert data_manipulation_bot.protocol == IPProtocol.NONE assert data_manipulation_bot.payload == "DELETE" @@ -70,4 +70,13 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) - assert len(dm_bot.connections) + assert len(dm_bot._host_db_client.connections) + + +def test_dm_bot_fails_without_db_client(dm_client): + dm_client.software_manager.uninstall("DatabaseClient") + dm_bot = dm_client.software_manager.software.get("DataManipulationBot") + assert dm_bot._host_db_client is None + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + dm_bot._perform_data_manipulation(p_of_success=1.0) + assert dm_bot.attack_stage is DataManipulationAttackStage.FAILED From afa775baff03a7754ff0818934563f9587a2a1cb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 15:52:34 +0000 Subject: [PATCH 14/19] Add test for new reward --- .../game_layer/test_rewards.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index fd8a89a4..53753967 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,7 +1,10 @@ -from primaite.game.agent.rewards import WebpageUnavailablePenalty +from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService from tests.conftest import ControlledAgent @@ -35,3 +38,44 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent.store_action(action) game.step() assert agent.reward_function.current_reward == -0.7 + + +def test_uc2_rewards(game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + server_1.software_manager.install(DatabaseService) + db_service = server_1.software_manager.software.get("DatabaseService") + db_service.start() + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) + db_client.run() + + router: Router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + + comp = GreenAdminDatabaseUnreachablePenalty("client_1") + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == 1.0 + + router.acl.remove_rule(position=2) + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == -1.0 From ef1a2dc3f4635db875d7970ba859dce7bbc021df Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 16:00:10 +0000 Subject: [PATCH 15/19] clear uc2 notebook outputs --- src/primaite/notebooks/uc2_demo.ipynb | 539 ++------------------------ 1 file changed, 22 insertions(+), 517 deletions(-) diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index 36942b73..94be8baa 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -352,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "tags": [] }, @@ -364,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "tags": [] }, @@ -389,169 +389,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Resetting environment, episode 0, avg. reward: 0.0\n", - "env created successfully\n", - "{'ACL': {1: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 0,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 2: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 1,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 3: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 2,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 4: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 3,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 5: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 4,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 6: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 5,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 7: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 6,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 8: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 7,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 9: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 8,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0},\n", - " 10: {'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'permission': 0,\n", - " 'position': 9,\n", - " 'protocol': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0}},\n", - " 'ICS': 0,\n", - " 'LINKS': {1: {'PROTOCOLS': {'ALL': 1}},\n", - " 2: {'PROTOCOLS': {'ALL': 1}},\n", - " 3: {'PROTOCOLS': {'ALL': 1}},\n", - " 4: {'PROTOCOLS': {'ALL': 1}},\n", - " 5: {'PROTOCOLS': {'ALL': 1}},\n", - " 6: {'PROTOCOLS': {'ALL': 1}},\n", - " 7: {'PROTOCOLS': {'ALL': 1}},\n", - " 8: {'PROTOCOLS': {'ALL': 1}},\n", - " 9: {'PROTOCOLS': {'ALL': 1}},\n", - " 10: {'PROTOCOLS': {'ALL': 0}}},\n", - " 'NODES': {1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}},\n", - " 'health_status': 1}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}},\n", - " 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0,\n", - " 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1}}}\n" - ] - } - ], + "outputs": [], "source": [ "# create the env\n", "with open(example_config_path(), 'r') as f:\n", @@ -579,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -597,51 +437,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 1, Red action: DO NOTHING, Blue reward:0.77\n", - "step: 2, Red action: DO NOTHING, Blue reward:0.77\n", - "step: 3, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 4, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 5, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 6, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 7, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 8, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 9, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 10, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 11, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 12, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 13, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 14, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 15, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 16, Red action: DO NOTHING, Blue reward:1.10\n", - "step: 17, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 18, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 19, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 20, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 21, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 22, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 23, Red action: DO NOTHING, Blue reward:1.20\n", - "step: 24, Red action: ATTACK from client 2, Blue reward:0.52\n", - "step: 25, Red action: DO NOTHING, Blue reward:0.52\n", - "step: 26, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 27, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 28, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 29, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 30, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 31, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 32, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 33, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 34, Red action: DO NOTHING, Blue reward:-0.80\n", - "step: 35, Red action: DO NOTHING, Blue reward:-0.80\n" - ] - } - ], + "outputs": [], "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", @@ -657,65 +455,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 1}}, 'health_status': 1}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 1, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 1}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1}}\n" - ] - } - ], + "outputs": [], "source": [ "pprint(obs['NODES'])" ] @@ -729,65 +471,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{1: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 2: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 3, 'operating_status': 1}},\n", - " 'operating_status': 1},\n", - " 3: {'FOLDERS': {1: {'FILES': {1: {'health_status': 2}}, 'health_status': 1}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 1, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 4: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 5: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 6: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1},\n", - " 7: {'FOLDERS': {1: {'FILES': {1: {'health_status': 0}}, 'health_status': 0}},\n", - " 'NETWORK_INTERFACES': {1: {'nic_status': 1,\n", - " 'nmne': {'inbound': 0, 'outbound': 1}},\n", - " 2: {'nic_status': 0,\n", - " 'nmne': {'inbound': 0, 'outbound': 0}}},\n", - " 'SERVICES': {1: {'health_status': 0, 'operating_status': 0}},\n", - " 'operating_status': 1}}\n" - ] - } - ], + "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", "obs, reward, terminated, truncated, info = env.step(1) # scan webapp service\n", @@ -818,21 +504,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 38\n", - "Red action: DONOTHING\n", - "Green action: DONOTHING\n", - "Green action: DONOTHING\n", - "Blue reward:-0.8\n" - ] - } - ], + "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", @@ -855,21 +529,9 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 52\n", - "Red action: DONOTHING\n", - "Green action: ('NODE_APPLICATION_EXECUTE', {'node_id': 0, 'application_id': 0})\n", - "Green action: ('NODE_APPLICATION_EXECUTE', {'node_id': 0, 'application_id': 0})\n", - "Blue reward:-0.80\n" - ] - } - ], + "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", @@ -890,49 +552,9 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 53, Red action: DONOTHING, Blue reward:-0.80\n", - "step: 54, Red action: DONOTHING, Blue reward:-0.80\n", - "step: 55, Red action: DONOTHING, Blue reward:-0.80\n", - "step: 56, Red action: DONOTHING, Blue reward:0.54\n", - "step: 57, Red action: DONOTHING, Blue reward:1.20\n", - "step: 58, Red action: DONOTHING, Blue reward:1.20\n", - "step: 59, Red action: DONOTHING, Blue reward:1.20\n", - "step: 60, Red action: DONOTHING, Blue reward:1.20\n", - "step: 61, Red action: DONOTHING, Blue reward:1.20\n", - "step: 62, Red action: DONOTHING, Blue reward:1.20\n", - "step: 63, Red action: DONOTHING, Blue reward:1.20\n", - "step: 64, Red action: DONOTHING, Blue reward:1.20\n", - "step: 65, Red action: DONOTHING, Blue reward:1.00\n", - "step: 66, Red action: DONOTHING, Blue reward:1.00\n", - "step: 67, Red action: DONOTHING, Blue reward:1.00\n", - "step: 68, Red action: DONOTHING, Blue reward:1.00\n", - "step: 69, Red action: DONOTHING, Blue reward:1.00\n", - "step: 70, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", - "step: 71, Red action: DONOTHING, Blue reward:1.00\n", - "step: 72, Red action: DONOTHING, Blue reward:1.00\n", - "step: 73, Red action: DONOTHING, Blue reward:1.00\n", - "step: 74, Red action: DONOTHING, Blue reward:1.00\n", - "step: 75, Red action: DONOTHING, Blue reward:0.80\n", - "step: 76, Red action: DONOTHING, Blue reward:0.80\n", - "step: 77, Red action: DONOTHING, Blue reward:0.80\n", - "step: 78, Red action: DONOTHING, Blue reward:0.80\n", - "step: 79, Red action: DONOTHING, Blue reward:0.80\n", - "step: 80, Red action: DONOTHING, Blue reward:0.80\n", - "step: 81, Red action: DONOTHING, Blue reward:0.80\n", - "step: 82, Red action: DONOTHING, Blue reward:0.80\n", - "step: 83, Red action: DONOTHING, Blue reward:0.80\n", - "step: 84, Red action: DONOTHING, Blue reward:0.80\n", - "step: 85, Red action: NODE_APPLICATION_EXECUTE, Blue reward:0.80\n" - ] - } - ], + "outputs": [], "source": [ "env.step(13) # Patch the database\n", "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", @@ -964,89 +586,9 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{1: {'position': 0,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 2: {'position': 1,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 3: {'position': 2,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 4: {'position': 3,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 5: {'position': 4,\n", - " 'permission': 2,\n", - " 'source_node_id': 7,\n", - " 'source_port': 1,\n", - " 'dest_node_id': 4,\n", - " 'dest_port': 1,\n", - " 'protocol': 3},\n", - " 6: {'position': 5,\n", - " 'permission': 2,\n", - " 'source_node_id': 8,\n", - " 'source_port': 1,\n", - " 'dest_node_id': 4,\n", - " 'dest_port': 1,\n", - " 'protocol': 3},\n", - " 7: {'position': 6,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 8: {'position': 7,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 9: {'position': 8,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0},\n", - " 10: {'position': 9,\n", - " 'permission': 0,\n", - " 'source_node_id': 0,\n", - " 'source_port': 0,\n", - " 'dest_node_id': 0,\n", - " 'dest_port': 0,\n", - " 'protocol': 0}}" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "obs['ACL']" ] @@ -1060,7 +602,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1082,46 +624,9 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step: 117, Red action: DONOTHING, Blue reward:1.00\n", - "step: 118, Red action: DONOTHING, Blue reward:1.00\n", - "step: 119, Red action: DONOTHING, Blue reward:1.00\n", - "step: 120, Red action: DONOTHING, Blue reward:1.00\n", - "step: 121, Red action: DONOTHING, Blue reward:1.00\n", - "step: 122, Red action: DONOTHING, Blue reward:1.00\n", - "step: 123, Red action: DONOTHING, Blue reward:1.00\n", - "step: 124, Red action: DONOTHING, Blue reward:1.00\n", - "step: 125, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", - "step: 126, Red action: DONOTHING, Blue reward:1.00\n", - "step: 127, Red action: DONOTHING, Blue reward:1.00\n", - "step: 128, Red action: DONOTHING, Blue reward:1.00\n", - "step: 129, Red action: DONOTHING, Blue reward:1.00\n", - "step: 130, Red action: DONOTHING, Blue reward:1.00\n", - "step: 131, Red action: DONOTHING, Blue reward:1.00\n", - "step: 132, Red action: DONOTHING, Blue reward:1.00\n", - "step: 133, Red action: DONOTHING, Blue reward:1.00\n", - "step: 134, Red action: DONOTHING, Blue reward:1.00\n", - "step: 135, Red action: DONOTHING, Blue reward:1.00\n", - "step: 136, Red action: DONOTHING, Blue reward:1.00\n", - "step: 137, Red action: DONOTHING, Blue reward:1.00\n", - "step: 138, Red action: DONOTHING, Blue reward:1.00\n", - "step: 139, Red action: DONOTHING, Blue reward:1.00\n", - "step: 140, Red action: DONOTHING, Blue reward:1.00\n", - "step: 141, Red action: DONOTHING, Blue reward:1.00\n", - "step: 142, Red action: DONOTHING, Blue reward:1.00\n", - "step: 143, Red action: DONOTHING, Blue reward:1.00\n", - "step: 144, Red action: NODE_APPLICATION_EXECUTE, Blue reward:1.00\n", - "step: 145, Red action: DONOTHING, Blue reward:1.00\n", - "step: 146, Red action: DONOTHING, Blue reward:1.00\n" - ] - } - ], + "outputs": [], "source": [ "for step in range(30):\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", From a6031d568d175a56a572f7e321cd29da2aad04a5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 3 Mar 2024 16:36:08 +0000 Subject: [PATCH 16/19] Remove unused import --- .../simulator/system/applications/red_applications/dos_bot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 1247bc99..202fd189 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -6,7 +6,6 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) From d1480e4477f42f46a99776711f80c3a71ee90934 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Mar 2024 09:58:57 +0000 Subject: [PATCH 17/19] Apply suggestions from PR review. --- docs/source/configuration/agents.rst | 6 +++--- .../system/services/database_service.rst | 7 +++++++ src/primaite/config/_package_data/example_config.yaml | 4 ++-- .../config/_package_data/example_config_2_rl_agents.yaml | 4 ++-- src/primaite/game/game.py | 2 +- .../simulator/network/transmission/network_layer.py | 2 +- .../applications/red_applications/data_manipulation_bot.py | 2 -- tests/assets/configs/bad_primaite_session.yaml | 2 +- tests/assets/configs/basic_firewall.yaml | 2 +- tests/assets/configs/basic_switched_network.yaml | 2 +- tests/assets/configs/dmz_network.yaml | 2 +- tests/assets/configs/eval_only_primaite_session.yaml | 2 +- tests/assets/configs/multi_agent_session.yaml | 2 +- tests/assets/configs/test_primaite_session.yaml | 2 +- tests/assets/configs/train_only_primaite_session.yaml | 2 +- tests/integration_tests/game_layer/test_rewards.py | 1 + 16 files changed, 25 insertions(+), 19 deletions(-) diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index ac67c365..b8912883 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo ... - ref: green_agent_example team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive ``type`` -------- -Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour. +Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour. Available agent types: -- ``probabilistic_agent`` +- ``ProbabilisticAgent`` - ``ProxyAgent`` - ``RedDatabaseCorruptingAgent`` diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index 2c962c0a..dd6dec41 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -25,6 +25,13 @@ Usage - Clients connect, execute queries, and disconnect. - Service runs on TCP port 5432 by default. +**Supported queries:** + +* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code. +* ``DELETE``: This query represents an attack, it will cause the database file to enter a ``COMPROMISED`` state, and return a status code 200. +* ``INSERT``: If the database service is in a healthy state, this will return a 200 status, if it's not in a healthy state it will return 404. +* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code. + Implementation ============== diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 45d29b48..8d1b4293 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent agent_settings: action_probabilities: 0: 0.3 @@ -76,7 +76,7 @@ agents: - ref: client_1_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent agent_settings: action_probabilities: 0: 0.3 diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index b6b07afa..260517b9 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -35,7 +35,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -64,7 +64,7 @@ agents: - ref: client_1_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index bfbefd3c..0749e5db 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -414,7 +414,7 @@ class PrimaiteGame: reward_function = RewardFunction.from_config(reward_function_cfg) # CREATE AGENT - if agent_type == "probabilistic_agent": + if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing settings = agent_cfg.get("agent_settings") new_agent = ProbabilisticAgent( diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index 22d7f97d..8ee0b4af 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -16,7 +16,7 @@ class IPProtocol(Enum): """ NONE = "none" - """Placeholder for a non-port.""" + """Placeholder for a non-protocol.""" TCP = "tcp" """Transmission Control Protocol.""" UDP = "udp" diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 961f82c2..ee98ea8e 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -38,9 +38,7 @@ class DataManipulationAttackStage(IntEnum): class DataManipulationBot(Application): """A bot that simulates a script which performs a SQL injection attack.""" - # server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None - # server_password: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 017492ad..38d54ce3 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -21,7 +21,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 0a892650..9d7b34cb 100644 --- a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -40,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 6c6b2845..9a0d5313 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -40,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 56a68410..95e09e16 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -65,7 +65,7 @@ game: agents: - ref: client_1_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index e70814f5..f2815578 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 6401bcda..8bbddb76 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -31,7 +31,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index c2616001..199cf8cc 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -29,7 +29,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 8ef4b8fd..71a23989 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: probabilistic_agent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 53753967..8edbf0ac 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -41,6 +41,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): def test_uc2_rewards(game_and_agent): + """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" game, agent = game_and_agent agent: ControlledAgent From ac9d550e9b2f3ff48a5f93f5612f34395dba9a6d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Mar 2024 10:43:38 +0000 Subject: [PATCH 18/19] Change get_action signature for agents --- .../game/agent/data_manipulation_bot.py | 14 ++++++------- src/primaite/game/agent/interface.py | 20 +++++++++---------- src/primaite/game/agent/rewards.py | 2 +- src/primaite/game/agent/scripted_agents.py | 12 +++++------ src/primaite/game/game.py | 3 +-- tests/conftest.py | 2 +- .../_game/_agent/test_probabilistic_agent.py | 2 +- 7 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index b5de9a5a..c758c926 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -1,5 +1,5 @@ import random -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple from gymnasium.core import ObsType @@ -26,14 +26,14 @@ class DataManipulationAgent(AbstractScriptedAgent): ) self.next_execution_timestep = timestep + random_timestep_increment - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: + """Waits until a specific timestep, then attempts to execute its data manipulation application. - :param obs: _description_ + :param obs: Current observation for this agent, not used in DataManipulationAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, used for scheduling actions + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 4f434bad..88848479 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -112,7 +112,7 @@ class AbstractAgent(ABC): return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -152,14 +152,14 @@ class AbstractScriptedAgent(AbstractAgent): class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Sample the action space randomly. - :param obs: _description_ + :param obs: Current observation for this agent, not used in RandomAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, not used in RandomAgent + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ return self.action_manager.get_action(self.action_manager.space.sample()) @@ -185,14 +185,14 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. :type obs: ObsType - :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. - :type reward: float, optional + :param timestep: Current simulation timestep. Not used by ProxyAgents, bur required for the interface. + :type timestep: int :return: Action to be taken in CAOS format. :rtype: Tuple[str, Dict] """ diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 882ad024..8c8e36ad 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -270,7 +270,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return -1.0 elif last_connection_successful is True: return 1.0 - return 0 + return 0.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index 28d94062..5111df32 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -70,17 +70,17 @@ class ProbabilisticAgent(AbstractScriptedAgent): super().__init__(agent_name, action_space, observation_space, reward_function) - def get_action(self, obs: ObsType, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ - Choose a random action from the action space. + Sample the action space randomly. The probability of each action is given by the corresponding index in ``self.probabilities``. - :param obs: Current observation of the simulation + :param obs: Current observation for this agent, not used in ProbabilisticAgent :type obs: ObsType - :param reward: Reward for the last step, not used for scripted agents, defaults to 0 - :type reward: float, optional - :return: Action to be taken in CAOS format. + :param timestep: The current simulation timestep, not used in ProbabilisticAgent + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 0749e5db..cd88d832 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -165,8 +165,7 @@ class PrimaiteGame: agent_actions = {} for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew, timestep=self.step_counter) + action_choice, options = agent.get_action(obs, timestep=self.step_counter) agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) diff --git a/tests/conftest.py b/tests/conftest.py index b60de730..a117a1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0, timestep: Optional[int] = None) -> Tuple[str, Dict]: + def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index f0b37cac..73228e36 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -69,7 +69,7 @@ def test_probabilistic_agent(): node_application_execute_count = 0 node_file_delete_count = 0 for _ in range(N_TRIALS): - a = pa.get_action(0, timestep=0) + a = pa.get_action(0) if a == ("DONOTHING", {}): do_nothing_count += 1 elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): From 2c3652979bfd11c6e322c256b64658ec5f404847 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Mar 2024 11:17:54 +0000 Subject: [PATCH 19/19] Add helpful error messages to action index errors --- src/primaite/game/agent/actions.py | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 392d07c6..84bd3f39 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -812,6 +812,13 @@ class ActionManager: :return: The node hostname. :rtype: str """ + if not node_idx < len(self.node_names): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" + f"has {len(self.node_names)} nodes." + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.node_names[node_idx] def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: @@ -825,6 +832,13 @@ class ActionManager: :return: The name of the folder. Or None if the node has fewer folders than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" + f" is out of range for its action space. Folder on each node: {self.folder_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.folder_names[node_idx][folder_idx] def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: @@ -840,6 +854,17 @@ class ActionManager: fewer files than the given index. :rtype: Optional[str] """ + if ( + node_idx >= len(self.file_names) + or folder_idx >= len(self.file_names[node_idx]) + or file_idx >= len(self.file_names[node_idx][folder_idx]) + ): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" + f" but this is out of range for its action space. Files on each node: {self.file_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.file_names[node_idx][folder_idx][file_idx] def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: @@ -852,6 +877,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" + f" is out of range for its action space. Services on each node: {self.service_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.service_names[node_idx][service_idx] def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: @@ -864,6 +896,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " + f"this is out of range for its action space. Applications on each node: {self.application_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.application_names[node_idx][application_idx] def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: @@ -874,6 +913,13 @@ class ActionManager: :return: The protocol. :rtype: str """ + if protocol_idx >= len(self.protocols): + msg = ( + f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" + f" is out of range for its action space. Protocols: {self.protocols}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.protocols[protocol_idx] def get_ip_address_by_idx(self, ip_idx: int) -> str: @@ -885,6 +931,13 @@ class ActionManager: :return: The IP address. :rtype: str """ + if ip_idx >= len(self.ip_address_list): + msg = ( + f"Error: agent attempted to perform an action on ip address {ip_idx} but this" + f" is out of range for its action space. IP address list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ip_address_list[ip_idx] def get_port_by_idx(self, port_idx: int) -> str: @@ -896,6 +949,13 @@ class ActionManager: :return: The port. :rtype: str """ + if port_idx >= len(self.ports): + msg = ( + f"Error: agent attempted to perform an action on port {port_idx} but this" + f" is out of range for its action space. Port list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ports[port_idx] def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: