diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index f85baf10..a0e9667e 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -33,7 +33,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -45,24 +45,39 @@ agents: - node_name: client_2 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - type: DUMMY agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: GreenUC2Agent observation_space: type: UC2GreenObservation action_space: @@ -74,14 +89,36 @@ agents: - node_name: client_1 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + reward_function: reward_components: - type: DUMMY + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + @@ -572,6 +609,14 @@ agents: weight: 0.33 options: node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.1 + options: + node_hostname: client_2 agent_settings: @@ -717,6 +762,10 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient @@ -740,6 +789,10 @@ simulation: data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index b5d5f998..acc37711 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -242,6 +242,46 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class GreenAdminDatabaseUnreachablePenalty(AbstractReward): + """Penalises the agent when the green db clients fail to connect to the database.""" + + def __init__(self, node_hostname: str) -> None: + """ + Initialise the reward component. + + :param node_hostname: Hostname of the node where the database client sits. + :type node_hostname: str + """ + self._node = node_hostname + self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + + def calculate(self, state: Dict) -> float: + """ + Calculate the reward based on current simulation state. + + :param state: The current state of the simulation. + :type state: Dict + """ + db_state = access_from_nested_dict(state, self.location_in_state) + if db_state is NOT_PRESENT_IN_STATE or "connections_status" not in db_state: + _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + connections_status = db_state["connections_status"] + if False in connections_status: + return -1.0 + return 0 + + @classmethod + def from_config(cls, config: Dict) -> AbstractReward: + """ + Build the reward component object from config. + + :param config: Configuration dictionary. + :type config: Dict + """ + node_hostname = config.get("node_hostname") + return cls(node_hostname=node_hostname) + + class RewardFunction: """Manages the reward function for the agent.""" @@ -250,6 +290,7 @@ class RewardFunction: "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, + "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, } def __init__(self): diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 50d9f3d4..67c0c9b4 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -1,8 +1,9 @@ from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from uuid import uuid4 from primaite import getLogger +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application @@ -25,6 +26,8 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} + _connections_status: List[bool] = [] + """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" @@ -33,6 +36,20 @@ class DatabaseClient(Application): super().__init__(**kwargs) self.set_original_state() + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + rm.add_request("execute", RequestType(func=lambda request, context: self.execute())) + return rm + + def execute(self) -> bool: + """Execution definition for db client: perform a select query.""" + if self.connections: + can_connect = self.connect(connection_id=list(self.connections.keys())[-1]) + else: + can_connect = self.connect() + self._connections_status.append(can_connect) + return can_connect + def set_original_state(self): """Sets the original state.""" _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") @@ -52,8 +69,11 @@ class DatabaseClient(Application): :return: A dictionary representing the current state. """ - pass - return super().describe_state() + state = super().describe_state() + # list of connections that were established or verified during this step. + state["connections_status"] = [c for c in self._connections_status] + self._connections_status.clear() + return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): """ @@ -74,6 +94,10 @@ class DatabaseClient(Application): if not connection_id: connection_id = str(uuid4()) + # if we are reusing a connection_id, remove it from self.connections so that its new status can be populated + # warning: janky + self._connections.pop(connection_id, None) + self.connected = self._connect( server_ip_address=self.server_ip_address, password=self.server_password, connection_id=connection_id )