diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 274da7aa..aff54d62 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -78,12 +78,12 @@ game_config: action_list: - type: DONOTHING # None: + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications} + self.verb: str + + def form_request(self, node_id: int, application_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_uuid = self.manager.get_node_uuid_by_idx(node_id) + application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id) + if node_uuid is None or application_uuid is None: + return ["do_nothing"] + return ["network", "node", node_uuid, "application", application_uuid, self.verb] + + +class NodeApplicationExecuteAction(NodeApplicationAbstractAction): + """Action which executes an application.""" + + def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications) + self.verb = "execute" + + class NodeFolderAbstractAction(AbstractAction): """ Base class for folder actions. @@ -536,6 +567,7 @@ class ActionManager: "NODE_SERVICE_RESTART": NodeServiceRestartAction, "NODE_SERVICE_DISABLE": NodeServiceDisableAction, "NODE_SERVICE_ENABLE": NodeServiceEnableAction, + "NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction, "NODE_FILE_SCAN": NodeFileScanAction, "NODE_FILE_CHECKHASH": NodeFileCheckhashAction, "NODE_FILE_DELETE": NodeFileDeleteAction, @@ -565,6 +597,7 @@ class ActionManager: max_folders_per_node: int = 2, # allows calculating shape max_files_per_folder: int = 2, # allows calculating shape max_services_per_node: int = 2, # allows calculating shape + max_applications_per_node: int = 10, # allows calculating shape max_nics_per_node: int = 8, # allows calculating shape max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol @@ -622,6 +655,7 @@ class ActionManager: "num_folders": max_folders_per_node, "num_files": max_files_per_folder, "num_services": max_services_per_node, + "num_applications": max_applications_per_node, "num_nics": max_nics_per_node, "num_acl_rules": max_acl_rules, "num_protocols": len(self.protocols), @@ -775,6 +809,21 @@ class ActionManager: service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids) > service_idx else None + def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: + """Get the application UUID corresponding to the given node and service indices. + + :param node_idx: The index of the node. + :type node_idx: int + :param application_idx: The index of the service on the node. + :type application_idx: int + :return: The UUID of the service. Or None if the node has fewer services than the given index. + :rtype: Optional[str] + """ + node_uuid = self.get_node_uuid_by_idx(node_idx) + node = self.sim.network.nodes[node_uuid] + application_uuids = list(node.applications.keys()) + return application_uuids[application_idx] if len(application_uuids) > application_idx else None + def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: """Get the internet protocol corresponding to the given index. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 5e73a423..33932df2 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -154,7 +154,17 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.next_execution_timestep = self.agent_settings.start_settings.start_step + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: """Randomly sample an action from the action space. @@ -166,21 +176,14 @@ class DataManipulationAgent(AbstractScriptedAgent): :return: _description_ :rtype: Tuple[str, Dict] """ - # TODO: Move this to the appropriate place - # return self.action_space.get_action(self.action_space.space.sample()) + current_timestep = self.action_space.session.step_counter - timestep = self.action_space.session.step_counter - - if timestep < self.next_execution_timestep: + if current_timestep < self.next_execution_timestep: return "DONOTHING", {"dummy": 0} - var = random.randint(-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance) - self.next_execution_timestep = timestep + self.agent_settings.start_settings.frequency + var + self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) - for bot in self.data_manipulation_bots: - bot.execute() - - return "DONOTHING", {"dummy": 0} + return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} class AbstractGATEAgent(AbstractAgent): diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py index e3f5b95d..f4b31cb1 100644 --- a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -3,6 +3,7 @@ from ipaddress import IPv4Address from typing import Optional from primaite.game.science import simulate_trial +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.database_client import DatabaseClient @@ -46,6 +47,13 @@ class DataManipulationBot(DatabaseClient): super().__init__(**kwargs) self.name = "DataManipulationBot" + def _init_request_manager(self) -> RequestManager: + rm = super()._init_request_manager() + + rm.add_request(name="execute", request_type=RequestType(func=self.execute)) + + return rm + def configure( self, server_ip_address: IPv4Address, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 5127254c..04e23e84 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -1,5 +1,3 @@ -from ipaddress import IPv4Address - import pytest from primaite.simulator.network.hardware.base import Node