diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index f32843b1..b8912883 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo ... - ref: green_agent_example team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive ``type`` -------- -Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour. +Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour. Available agent types: -- ``GreenWebBrowsingAgent`` +- ``ProbabilisticAgent`` - ``ProxyAgent`` - ``RedDatabaseCorruptingAgent`` diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 304621dd..9188733b 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -45,7 +45,7 @@ In a simulation, the bot can be controlled by using ``DataManipulationAgent`` wh Implementation ============== -The bot extends :ref:`DatabaseClient` and leverages its connectivity. +The bot connects to a :ref:`DatabaseClient` and leverages its connectivity. The host running ``DataManipulationBot`` must also have a :ref:`DatabaseClient` installed on it. - Uses the Application base class for lifecycle management. - Credentials, target IP and other options set via ``configure``. @@ -65,6 +65,7 @@ Python from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot + from primaite.simulator.system.applications.database_client import DatabaseClient client_1 = Computer( hostname="client_1", @@ -74,6 +75,7 @@ Python operating_state=NodeOperatingState.ON # initialise the computer in an ON state ) network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") @@ -148,6 +150,10 @@ If not using the data manipulation bot manually, it needs to be used with a data data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 Configuration ============= diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index 2c962c0a..dd6dec41 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -25,6 +25,13 @@ Usage - Clients connect, execute queries, and disconnect. - Service runs on TCP port 5432 by default. +**Supported queries:** + +* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code. +* ``DELETE``: This query represents an attack, it will cause the database file to enter a ``COMPROMISED`` state, and return a status code 200. +* ``INSERT``: If the database service is in a healthy state, this will return a 200 status, if it's not in a healthy state it will return 404. +* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code. + Implementation ============== diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 478124a9..8d1b4293 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,12 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -45,24 +50,38 @@ agents: - node_name: client_2 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 - - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -74,10 +93,26 @@ agents: - node_name: client_1 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + reward_function: reward_components: - type: DUMMY @@ -85,6 +120,7 @@ agents: + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -572,6 +608,14 @@ agents: weight: 0.33 options: node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_2 agent_settings: @@ -722,6 +766,10 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient @@ -745,6 +793,10 @@ simulation: data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 2b54eb37..260517b9 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -35,7 +35,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -64,7 +64,7 @@ agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index b85cf86c..84bd3f39 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -607,7 +607,6 @@ class ActionManager: def __init__( self, - game: "PrimaiteGame", # reference to game for information lookup actions: List[Dict], # stores list of actions available to agent nodes: List[Dict], # extra configuration for each node max_folders_per_node: int = 2, # allows calculating shape @@ -618,7 +617,7 @@ class ActionManager: max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. + ip_address_list: List[str] = [], # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: """Init method for ActionManager. @@ -649,7 +648,6 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.game: "PrimaiteGame" = game self.node_names: List[str] = [n["node_name"] for n in nodes] """List of node names in this action space. The list order is the mapping between node index and node name.""" self.application_names: List[List[str]] = [] @@ -707,25 +705,7 @@ class ActionManager: self.protocols: List[str] = protocols self.ports: List[str] = ports - self.ip_address_list: List[str] - - # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from - # the nodes in the simulation. - # TODO: refactor. Options: - # 1: This should be pulled out into it's own function for clarity - # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to - # go through the nodes here. - if ip_address_list is not None: - self.ip_address_list = ip_address_list - else: - self.ip_address_list = [] - for node_name in self.node_names: - node_obj = self.game.simulation.network.get_node_by_hostname(node_name) - if node_obj is None: - continue - network_interfaces = node_obj.network_interfaces - for nic_uuid, nic_obj in network_interfaces.items(): - self.ip_address_list.append(nic_obj.ip_address) + self.ip_address_list: List[str] = ip_address_list # action_args are settings which are applied to the action space as a whole. global_action_args = { @@ -832,6 +812,13 @@ class ActionManager: :return: The node hostname. :rtype: str """ + if not node_idx < len(self.node_names): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" + f"has {len(self.node_names)} nodes." + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.node_names[node_idx] def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: @@ -845,6 +832,13 @@ class ActionManager: :return: The name of the folder. Or None if the node has fewer folders than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" + f" is out of range for its action space. Folder on each node: {self.folder_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.folder_names[node_idx][folder_idx] def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: @@ -860,6 +854,17 @@ class ActionManager: fewer files than the given index. :rtype: Optional[str] """ + if ( + node_idx >= len(self.file_names) + or folder_idx >= len(self.file_names[node_idx]) + or file_idx >= len(self.file_names[node_idx][folder_idx]) + ): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" + f" but this is out of range for its action space. Files on each node: {self.file_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.file_names[node_idx][folder_idx][file_idx] def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: @@ -872,6 +877,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" + f" is out of range for its action space. Services on each node: {self.service_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.service_names[node_idx][service_idx] def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: @@ -884,6 +896,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " + f"this is out of range for its action space. Applications on each node: {self.application_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.application_names[node_idx][application_idx] def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: @@ -894,6 +913,13 @@ class ActionManager: :return: The protocol. :rtype: str """ + if protocol_idx >= len(self.protocols): + msg = ( + f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" + f" is out of range for its action space. Protocols: {self.protocols}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.protocols[protocol_idx] def get_ip_address_by_idx(self, ip_idx: int) -> str: @@ -905,6 +931,13 @@ class ActionManager: :return: The IP address. :rtype: str """ + if ip_idx >= len(self.ip_address_list): + msg = ( + f"Error: agent attempted to perform an action on ip address {ip_idx} but this" + f" is out of range for its action space. IP address list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ip_address_list[ip_idx] def get_port_by_idx(self, port_idx: int) -> str: @@ -916,6 +949,13 @@ class ActionManager: :return: The port. :rtype: str """ + if port_idx >= len(self.ports): + msg = ( + f"Error: agent attempted to perform an action on port {port_idx} but this" + f" is out of range for its action space. Port list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ports[port_idx] def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: @@ -958,6 +998,12 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from + # the nodes in the simulation. + # TODO: refactor. Options: + # 1: This should be pulled out into it's own function for clarity + # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to + # go through the nodes here. ip_address_order = cfg["options"].pop("ip_address_order", {}) ip_address_list = [] for entry in ip_address_order: @@ -967,13 +1013,22 @@ class ActionManager: ip_address = node_obj.network_interface[nic_num].ip_address ip_address_list.append(ip_address) + if not ip_address_list: + node_names = [n["node_name"] for n in cfg.get("nodes", {})] + for node_name in node_names: + node_obj = game.simulation.network.get_node_by_hostname(node_name) + if node_obj is None: + continue + network_interfaces = node_obj.network_interfaces + for nic_uuid, nic_obj in network_interfaces.items(): + ip_address_list.append(nic_obj.ip_address) + obj = cls( - game=game, actions=cfg["action_list"], **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=ip_address_list or None, + ip_address_list=ip_address_list, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 126c55ec..c758c926 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -26,22 +26,20 @@ class DataManipulationAgent(AbstractScriptedAgent): ) self.next_execution_timestep = timestep + random_timestep_increment - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: + """Waits until a specific timestep, then attempts to execute its data manipulation application. - :param obs: _description_ + :param obs: Current observation for this agent, not used in DataManipulationAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, used for scheduling actions + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ - current_timestep = self.action_manager.game.step_counter - - if current_timestep < self.next_execution_timestep: + if timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} - self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) + self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 276715f7..88848479 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -112,7 +112,7 @@ class AbstractAgent(ABC): return self.reward_function.update(state) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -122,6 +122,8 @@ class AbstractAgent(ABC): :type obs: ObsType :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? :type reward: float, optional + :param timestep: The current timestep in the simulation, used for non-RL agents. Optional + :type timestep: int :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ @@ -144,20 +146,20 @@ class AbstractAgent(ABC): class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" - ... + pass class RandomAgent(AbstractScriptedAgent): """Agent that ignores its observation and acts completely at random.""" - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Sample the action space randomly. - :param obs: _description_ + :param obs: Current observation for this agent, not used in RandomAgent :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ + :param timestep: The current simulation timestep, not used in RandomAgent + :type timestep: int + :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ return self.action_manager.get_action(self.action_manager.space.sample()) @@ -183,14 +185,14 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. :type obs: ObsType - :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. - :type reward: float, optional + :param timestep: Current simulation timestep. Not used by ProxyAgents, bur required for the interface. + :type timestep: int :return: Action to be taken in CAOS format. :rtype: Tuple[str, Dict] """ diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index ba6d1fa3..8c8e36ad 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -242,6 +242,48 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class GreenAdminDatabaseUnreachablePenalty(AbstractReward): + """Penalises the agent when the green db clients fail to connect to the database.""" + + def __init__(self, node_hostname: str) -> None: + """ + Initialise the reward component. + + :param node_hostname: Hostname of the node where the database client sits. + :type node_hostname: str + """ + self._node = node_hostname + self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + + def calculate(self, state: Dict) -> float: + """ + Calculate the reward based on current simulation state. + + :param state: The current state of the simulation. + :type state: Dict + """ + db_state = access_from_nested_dict(state, self.location_in_state) + if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: + _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_connection_successful = db_state["last_connection_successful"] + if last_connection_successful is False: + return -1.0 + elif last_connection_successful is True: + return 1.0 + return 0.0 + + @classmethod + def from_config(cls, config: Dict) -> AbstractReward: + """ + Build the reward component object from config. + + :param config: Configuration dictionary. + :type config: Dict + """ + node_hostname = config.get("node_hostname") + return cls(node_hostname=node_hostname) + + class RewardFunction: """Manages the reward function for the agent.""" @@ -250,6 +292,7 @@ class RewardFunction: "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, + "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, } """List of reward class identifiers.""" diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py index 3748494b..5111df32 100644 --- a/src/primaite/game/agent/scripted_agents.py +++ b/src/primaite/game/agent/scripted_agents.py @@ -1,14 +1,87 @@ """Agents with predefined behaviours.""" +from typing import Dict, Optional, Tuple + +import numpy as np +import pydantic +from gymnasium.core import ObsType + +from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.rewards import RewardFunction -class GreenWebBrowsingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to send web requests to a target node.""" +class ProbabilisticAgent(AbstractScriptedAgent): + """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" - raise NotImplementedError + class Settings(pydantic.BaseModel): + """Config schema for Probabilistic agent settings.""" + model_config = pydantic.ConfigDict(extra="forbid") + """Strict validation.""" + action_probabilities: Dict[int, float] + """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" + random_seed: Optional[int] = None + """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" + # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way + # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. -class RedDatabaseCorruptingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to corrupt the database of the target node.""" + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: + """Make sure the probabilities sum to 1.""" + if not abs(sum(v.values()) - 1) < 1e-6: + raise ValueError("Green action probabilities must sum to 1") + return v - raise NotImplementedError + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]: + """Ensure that the keys of the probability dictionary cover all integers from 0 to N.""" + if not all((i in v) for i in range(len(v))): + raise ValueError( + "Green action probabilities must be defined as a mapping where the keys are consecutive integers " + "from 0 to N." + ) + return v + + def __init__( + self, + agent_name: str, + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + settings: Dict = {}, + ) -> None: + # If the action probabilities are not specified, create equal probabilities for all actions + if "action_probabilities" not in settings: + num_actions = len(action_space.action_map) + settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} + + # If seed not specified, set it to None so that numpy chooses a random one. + settings.setdefault("random_seed") + + self.settings = ProbabilisticAgent.Settings(**settings) + + self.rng = np.random.default_rng(self.settings.random_seed) + + # convert probabilities from + self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) + + super().__init__(agent_name, action_space, observation_space, reward_function) + + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """ + Sample the action space randomly. + + The probability of each action is given by the corresponding index in ``self.probabilities``. + + :param obs: Current observation for this agent, not used in ProbabilisticAgent + :type obs: ObsType + :param timestep: The current simulation timestep, not used in ProbabilisticAgent + :type timestep: int + :return: Action formatted in CAOS format + :rtype: Tuple[str, Dict] + """ + choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) + return self.action_manager.get_action(choice) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 2659abef..cd88d832 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,9 +7,10 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents import ProbabilisticAgent from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -164,8 +165,7 @@ class PrimaiteGame: agent_actions = {} for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew) + action_choice, options = agent.get_action(obs, timestep=self.step_counter) agent_actions[agent.agent_name] = (action_choice, options) request = agent.format_request(action_choice, options) self.simulation.apply_request(request) @@ -299,7 +299,7 @@ class PrimaiteGame: if service_type == "DatabaseService": if "options" in service_cfg: opt = service_cfg["options"] - new_service.password = opt.get("backup_server_ip", None) + new_service.password = opt.get("db_password", None) new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) if service_type == "FTPServer": if "options" in service_cfg: @@ -412,20 +412,19 @@ class PrimaiteGame: # CREATE REWARD FUNCTION reward_function = RewardFunction.from_config(reward_function_cfg) - # OTHER AGENT SETTINGS - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - # CREATE AGENT - if agent_type == "GreenWebBrowsingAgent": + if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing - new_agent = RandomAgent( + settings = agent_cfg.get("agent_settings") + new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=reward_function, - agent_settings=agent_settings, + settings=settings, ) elif agent_type == "ProxyAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -435,6 +434,8 @@ class PrimaiteGame: ) game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/uc2_demo.ipynb index c8f2595b..94be8baa 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/uc2_demo.ipynb @@ -13,7 +13,7 @@ "source": [ "## Scenario\n", "\n", - "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database. Occasionally, admins need to access the database directly from the clients.\n", "\n", "[](_package_data/uc2_network.png)\n", "\n", @@ -46,7 +46,9 @@ "source": [ "## Green agent\n", "\n", - "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available.\n", + "\n", + "Sometimes, the green agents send a request directly to the database to check that it is reachable." ] }, { @@ -68,7 +70,9 @@ "source": [ "## Blue agent\n", "\n", - "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router." + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router.\n", + "\n", + "However, these rules will also impact greens' ability to check the database connection. The blue agent should only block the infected client, it should let the other client connect freely. Once the attack has begun, automated traffic monitoring will detect it as suspicious network traffic. The blue agent's observation space will show this as an increase in the number of malicious network events (NMNE) on one of the network interfaces. To achieve optimal reward, the agent should only block the client which has the non-zero outbound NMNE." ] }, { @@ -101,9 +105,11 @@ "The red agent does not use information about the state of the network to decide its action.\n", "\n", "### Green\n", - "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, it will do nothing 30% of the time, send a web request 60% of the time, and send a db status check 10% of the time.\n", "\n", - "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender.\n", + "\n", + "Also, when the green agent is blocked from checking the database status, it causes a small negative reward." ] }, { @@ -322,9 +328,10 @@ "source": [ "## Reward Function\n", "\n", - "The blue agent's reward is calculated using two measures:\n", + "The blue agent's reward is calculated using these measures:\n", "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", "2. Whether each green agents' most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "3. Whether each green agents' most recent DB status check was successful (+1 for a successful connection, -1 for no connection).\n", "\n", "The file status reward and the two green-agent-related rewards are averaged to get a total step reward.\n" ] @@ -346,7 +353,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "%load_ext autoreload\n", @@ -356,7 +365,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "# Imports\n", @@ -403,7 +414,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -1.0 when green agents try to access the webpage." + "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -0.8 when green agents try to access the webpage." ] }, { @@ -432,7 +443,7 @@ "source": [ "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", - " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" ] }, { @@ -477,6 +488,13 @@ "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) has increased from 0 to 1. This tells us which client is being used by the red agent." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -506,7 +524,7 @@ "\n", "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { @@ -518,9 +536,9 @@ "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n", - "print(f\"Blue reward:{reward}\" )" + "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", + "print(f\"Blue reward:{reward:.2f}\" )" ] }, { @@ -539,24 +557,24 @@ "outputs": [], "source": [ "env.step(13) # Patch the database\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(26) # Block client 1\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "env.step(27) # Block client 2\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )\n", "\n", "for step in range(30):\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, even though the red agent executes an attack, the reward stays at 1.0" + "Now, even though the red agent executes an attack, the reward stays at 0.8." ] }, { @@ -575,6 +593,46 @@ "obs['ACL']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can slightly increase the reward by unblocking the client which isn't being used by the attacker. If node 6 has outbound NMNEs, let's unblock client 2, and if node 7 has outbound NMNEs, let's unblock client 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if obs['NODES'][6]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", + " # client 1 has NMNEs, let's unblock client 2\n", + " env.step(34) # remove ACL rule 6\n", + "elif obs['NODES'][7]['NETWORK_INTERFACES'][1]['nmne']['outbound'] == 1:\n", + " env.step(33) # remove ACL rule 5\n", + "else:\n", + " print(\"something went wrong, neither client has NMNEs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the reward will eventually increase to 1.0, even after red agent attempts subsequent attacks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for step in range(30):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'][0]}, Blue reward:{reward:.2f}\" )" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -590,13 +648,6 @@ "source": [ "env.reset()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index f8dbab9d..d54503a3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,3 +1,4 @@ +import copy import json from typing import Any, Dict, Optional, SupportsFloat, Tuple @@ -23,7 +24,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.game_config: Dict = game_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" @@ -78,7 +79,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index b8f80e95..d244f6b0 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings @@ -40,7 +39,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self, game: PrimaiteGame): + def __init__(self, game_cfg: Dict): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -57,7 +56,7 @@ class PrimaiteSession: self.io_manager: Optional["SessionIO"] = None """IO manager for the session.""" - self.game: PrimaiteGame = game + self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" def start_session(self) -> None: @@ -93,9 +92,7 @@ class PrimaiteSession: io_settings = cfg.get("io_settings", {}) io_manager = SessionIO(SessionIOSettings(**io_settings)) - game = PrimaiteGame.from_config(cfg) - - sess = cls(game=game) + sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index f82dee4a..fa9d86ef 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -146,6 +146,9 @@ def arcd_uc2_network() -> Network: ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + db_client_1 = client_1.software_manager.install(DatabaseClient) + db_client_1 = client_1.software_manager.software.get("DatabaseClient") + db_client_1.run() client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") db_manipulation_bot.configure( @@ -165,6 +168,9 @@ def arcd_uc2_network() -> Network: start_up_duration=0, ) client_2.power_on() + client_2.software_manager.install(DatabaseClient) + db_client_2 = client_2.software_manager.software.get("DatabaseClient") + db_client_2.run() web_browser = client_2.software_manager.software.get("WebBrowser") web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) @@ -194,67 +200,10 @@ def arcd_uc2_network() -> Network: database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) - ddl = """ - CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name VARCHAR(50) NOT NULL, - email VARCHAR(50) NOT NULL, - age INT, - city VARCHAR(50), - occupation VARCHAR(50) - );""" - - user_insert_statements = [ - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", - # noqa - ] database_server.software_manager.install(DatabaseService) database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) - database_service._process_sql(ddl, None, None) # noqa - for insert_statement in user_insert_statements: - database_service._process_sql(insert_statement, None, None) # noqa # Web Server web_server = Server( diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index dc848ade..8ee0b4af 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -15,6 +15,8 @@ class IPProtocol(Enum): .. _List of IPProtocols: """ + NONE = "none" + """Placeholder for a non-protocol.""" TCP = "tcp" """Transmission Control Protocol.""" UDP = "udp" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 57fd3a46..7b259ff4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional from uuid import uuid4 from primaite import getLogger +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application @@ -25,6 +26,8 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} + _last_connection_successful: Optional[bool] = None + """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" @@ -32,14 +35,30 @@ class DatabaseClient(Application): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request("execute", RequestType(func=lambda request, context: self.execute())) + return rm + + def execute(self) -> bool: + """Execution definition for db client: perform a select query.""" + if self.connections: + can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) + else: + can_connect = self.check_connection(connection_id=str(uuid4())) + self._last_connection_successful = can_connect + return can_connect + def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. :return: A dictionary representing the current state. """ - pass - return super().describe_state() + state = super().describe_state() + # list of connections that were established or verified during this step. + state["last_connection_successful"] = self._last_connection_successful + return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): """ @@ -65,6 +84,18 @@ class DatabaseClient(Application): ) return self.connected + def check_connection(self, connection_id: str) -> bool: + """Check whether the connection can be successfully re-established. + + :param connection_id: connection ID to check + :type connection_id: str + :return: Whether the connection was successfully re-established. + :rtype: bool + """ + if not self._can_perform_action(): + return False + return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id) + def _connect( self, server_ip_address: IPv4Address, diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index 5fe951b7..ee98ea8e 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -1,10 +1,13 @@ from enum import IntEnum from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -32,12 +35,10 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(DatabaseClient): +class DataManipulationBot(Application): """A bot that simulates a script which performs a SQL injection attack.""" - server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None - server_password: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -46,8 +47,31 @@ class DataManipulationBot(DatabaseClient): "Whether to repeat attacking once finished." def __init__(self, **kwargs): + kwargs["name"] = "DataManipulationBot" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) - self.name = "DataManipulationBot" + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + state = super().describe_state() + return state + + @property + def _host_db_client(self) -> DatabaseClient: + """Return the database client that is installed on the same machine as the DataManipulationBot.""" + db_client = self.software_manager.software.get("DatabaseClient") + if db_client is None: + _LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.") + return db_client def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() @@ -76,8 +100,8 @@ class DataManipulationBot(DatabaseClient): :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address - self.payload = payload self.server_password = server_password + self.payload = payload self.port_scan_p_of_success = port_scan_p_of_success self.data_manipulation_p_of_success = data_manipulation_p_of_success self.repeat = repeat @@ -123,15 +147,21 @@ class DataManipulationBot(DatabaseClient): :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. """ + if self._host_db_client is None: + self.attack_stage = DataManipulationAttackStage.FAILED + return + + self._host_db_client.server_ip_address = self.server_ip_address + self._host_db_client.server_password = self.server_password if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not len(self.connections): - self.connect() - if len(self.connections): - self.query(self.payload) + if not len(self._host_db_client.connections): + self._host_db_client.connect() + if len(self._host_db_client.connections): + self._host_db_client.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True if attack_successful: diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 9dac6b25..202fd189 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -6,7 +6,6 @@ from primaite import getLogger from primaite.game.science import simulate_trial from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -28,7 +27,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient, Application): +class DoSBot(DatabaseClient): """A bot that simulates a Denial of Service attack.""" target_ip_address: Optional[IPv4Address] = None diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6f2c479c..9fa86328 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -199,7 +199,7 @@ class WebBrowser(Application): def state(self) -> Dict: """Return the contents of this dataclass as a dict for use with describe_state method.""" if self.status == self._HistoryItemStatus.LOADED: - outcome = self.response_code + outcome = self.response_code.value else: outcome = self.status.value return {"url": self.url, "outcome": outcome} diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 9fdfd5ff..c73132eb 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -221,6 +221,18 @@ class DatabaseService(Service): } else: return {"status_code": 404, "data": False} + elif query == "SELECT * FROM pg_stat_activity": + # Check if the connection is active. + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 401, "data": False} else: # Invalid query self.sys_log.info(f"{self.name}: Invalid {query}") diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index c76aeef6..38d54ce3 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -21,7 +21,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 71dc31a7..9d7b34cb 100644 --- a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -40,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index daa40aa7..9a0d5313 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -40,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 880735d9..95e09e16 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -65,7 +65,7 @@ game: agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 1cb59f87..f2815578 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index b1b15372..8bbddb76 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -31,7 +31,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index e5f9d544..199cf8cc 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -29,7 +29,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index 10e088d8..71a23989 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/conftest.py b/tests/conftest.py index a8c3bdd1..a117a1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from datetime import datetime from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import pytest import yaml @@ -328,7 +328,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action @@ -497,7 +497,6 @@ def game_and_agent(): ] action_space = ActionManager( - game=game, actions=actions, # ALL POSSIBLE ACTIONS nodes=[ { diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 7785e4ae..da13dcd8 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -21,15 +21,15 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.game.simulation - assert len(session.game.agents) == 3 - assert len(session.game.rl_agents) == 1 + assert session.env.game.simulation + assert len(session.env.game.agents) == 3 + assert len(session.env.game.rl_agents) == 1 assert session.policy assert session.env - assert session.game.simulation.network - assert len(session.game.simulation.network.nodes) == 10 + assert session.env.game.simulation.network + assert len(session.env.game.simulation.network.nodes) == 10 @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session): diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index fd8a89a4..8edbf0ac 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,7 +1,10 @@ -from primaite.game.agent.rewards import WebpageUnavailablePenalty +from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService from tests.conftest import ControlledAgent @@ -35,3 +38,45 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent.store_action(action) game.step() assert agent.reward_function.current_reward == -0.7 + + +def test_uc2_rewards(game_and_agent): + """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" + game, agent = game_and_agent + agent: ControlledAgent + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + server_1.software_manager.install(DatabaseService) + db_service = server_1.software_manager.software.get("DatabaseService") + db_service.start() + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) + db_client.run() + + router: Router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + + comp = GreenAdminDatabaseUnreachablePenalty("client_1") + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == 1.0 + + router.acl.remove_rule(position=2) + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate(state) + assert reward_value == -1.0 diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py new file mode 100644 index 00000000..73228e36 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -0,0 +1,84 @@ +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents import ProbabilisticAgent + + +def test_probabilistic_agent(): + """ + Check that the probabilistic agent selects actions with approximately the right probabilities. + + Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator), + we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values + were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number + generation. + """ + N_TRIALS = 10_000 + P_DO_NOTHING = 0.1 + P_NODE_APPLICATION_EXECUTE = 0.3 + P_NODE_FILE_DELETE = 0.6 + MIN_DO_NOTHING = 850 + MAX_DO_NOTHING = 1150 + MIN_NODE_APPLICATION_EXECUTE = 2800 + MAX_NODE_APPLICATION_EXECUTE = 3200 + MIN_NODE_FILE_DELETE = 5750 + MAX_NODE_FILE_DELETE = 6250 + + action_space = ActionManager( + actions=[ + {"type": "DONOTHING"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_DELETE"}, + ], + nodes=[ + { + "node_name": "client_1", + "applications": [{"application_name": "WebBrowser"}], + "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + }, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + act_map={ + 0: {"action": "DONOTHING", "options": {}}, + 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, + 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, + }, + ) + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() + + pa = ProbabilisticAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + settings={ + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + "random_seed": 120, + }, + ) + + do_nothing_count = 0 + node_application_execute_count = 0 + node_file_delete_count = 0 + for _ in range(N_TRIALS): + a = pa.get_action(0) + if a == ("DONOTHING", {}): + do_nothing_count += 1 + elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + node_application_execute_count += 1 + elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + node_file_delete_count += 1 + else: + raise AssertionError("Probabilistic agent produced an unexpected action.") + + assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING + assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE + assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 2ca67119..6d00886a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -26,8 +26,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.POSTGRES_SERVER - assert data_manipulation_bot.protocol == IPProtocol.TCP + assert data_manipulation_bot.port == Port.NONE + assert data_manipulation_bot.protocol == IPProtocol.NONE assert data_manipulation_bot.payload == "DELETE" @@ -70,4 +70,13 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) - assert len(dm_bot.connections) + assert len(dm_bot._host_db_client.connections) + + +def test_dm_bot_fails_without_db_client(dm_client): + dm_client.software_manager.uninstall("DatabaseClient") + dm_bot = dm_client.software_manager.software.get("DataManipulationBot") + assert dm_bot._host_db_client is None + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + dm_bot._perform_data_manipulation(p_of_success=1.0) + assert dm_bot.attack_stage is DataManipulationAttackStage.FAILED