diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py index 10ce6254..8e61c1cb 100644 --- a/src/primaite/interface/request.py +++ b/src/primaite/interface/request.py @@ -1,6 +1,9 @@ -from typing import Dict, Literal +from typing import Dict, ForwardRef, Literal -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, validate_call + +RequestResponse = ForwardRef("RequestResponse") +"""This makes it possible to type-hint RequestResponse.from_bool return type.""" class RequestResponse(BaseModel): @@ -9,21 +12,33 @@ class RequestResponse(BaseModel): model_config = ConfigDict(extra="forbid") """Cannot have extra fields in the response. Anything custom goes into the data field.""" - status: Literal["pending", "success", "failure"] = "pending" + status: Literal["pending", "success", "failure", "unreachable"] = "pending" """ What is the current status of the request: - pending - the request has not been received yet, or it has been received but it's still being processed. - - success - the request has successfully been received and processed. - - failure - the request could not reach it's intended target or it was rejected. - - Note that the failure status should only be used when the request cannot be processed, for instance when the - target SimComponent doesn't exist, or is in an OFF state that prevents it from accepting requests. If the - request is received by the target and the associated action is executed, but couldn't be completed due to - downstream factors, the request was still successfully received, it's just that the result wasn't what was - intended. + - success - the request has been received and executed successfully. + - failure - the request has been received and attempted, but execution failed. + - unreachable - the request could not reach it's intended target, either because it doesn't exist or the target + is off. """ data: Dict = {} """Catch-all place to provide any additional data that was generated as a response to the request.""" # TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too # much. However, in the future we might consider making them mandatory. + + @classmethod + @validate_call + def from_bool(cls, status_bool: bool) -> RequestResponse: + """ + Construct a basic request response from a boolean. + + True maps to a success status. False maps to a failure status. + + :param status_bool: Whether to create a successful response + :type status_bool: bool + """ + if status_bool is True: + return cls(status="success", data={}) + elif status_bool is False: + return cls(status="failure", data={}) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 9ea59305..64f33f6a 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -96,7 +96,7 @@ class RequestManager(BaseModel): # _LOGGER.error(msg) # raise RuntimeError(msg) _LOGGER.debug(msg) - return RequestResponse(status="failure", data={"reason": msg}) + return RequestResponse(status="unreachable", data={"reason": msg}) request_type = self.request_types[request_key] @@ -226,7 +226,7 @@ class SimComponent(BaseModel): """ if self._request_manager is None: return - self._request_manager(request, context) + return self._request_manager(request, context) def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 8fd4e5d7..3ff73a80 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -7,6 +7,7 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_type import FileType @@ -41,12 +42,16 @@ class FileSystem(SimComponent): self._delete_manager.add_request( name="file", request_type=RequestType( - func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1]) + func=lambda request, context: RequestResponse.from_bool( + self.delete_file(folder_name=request[0], file_name=request[1]) + ) ), ) self._delete_manager.add_request( name="folder", - request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0])) + ), ) rm.add_request( name="delete", @@ -57,12 +62,16 @@ class FileSystem(SimComponent): self._restore_manager.add_request( name="file", request_type=RequestType( - func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1]) + func=lambda request, context: RequestResponse( + self.restore_file(folder_name=request[0], file_name=request[1]) + ) ), ) self._restore_manager.add_request( name="folder", - request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0])) + ), ) rm.add_request( name="restore", @@ -138,7 +147,7 @@ class FileSystem(SimComponent): ) return folder - def delete_folder(self, folder_name: str): + def delete_folder(self, folder_name: str) -> bool: """ Deletes a folder, removes it from the folders list and removes any child folders and files. @@ -146,24 +155,26 @@ class FileSystem(SimComponent): """ if folder_name == "root": self.sys_log.warning("Cannot delete the root folder.") - return + return False folder = self.get_folder(folder_name) - if folder: - # set folder to deleted state - folder.delete() - - # remove from folder list - self.folders.pop(folder.uuid) - - # add to deleted list - folder.remove_all_files() - - self.deleted_folders[folder.uuid] = folder - self.sys_log.info(f"Deleted folder /{folder.name} and its contents") - else: + if not folder: _LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}") + return False - def delete_folder_by_id(self, folder_uuid: str): + # set folder to deleted state + folder.delete() + + # remove from folder list + self.folders.pop(folder.uuid) + + # add to deleted list + folder.remove_all_files() + + self.deleted_folders[folder.uuid] = folder + self.sys_log.info(f"Deleted folder /{folder.name} and its contents") + return True + + def delete_folder_by_id(self, folder_uuid: str) -> None: """ Deletes a folder via its uuid. @@ -297,7 +308,7 @@ class FileSystem(SimComponent): return file - def delete_file(self, folder_name: str, file_name: str): + def delete_file(self, folder_name: str, file_name: str) -> bool: """ Delete a file by its name from a specific folder. @@ -309,8 +320,10 @@ class FileSystem(SimComponent): file = folder.get_file(file_name) if file: folder.remove_file(file) + return True + return False - def delete_file_by_id(self, folder_uuid: str, file_uuid: str): + def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None: """ Deletes a file via its uuid. @@ -327,7 +340,7 @@ class FileSystem(SimComponent): else: self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})") - def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): + def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None: """ Move a file from one folder to another. @@ -404,7 +417,7 @@ class FileSystem(SimComponent): # Agent actions ############################################################### - def scan(self, instant_scan: bool = False): + def scan(self, instant_scan: bool = False) -> None: """ Scan all the folders (and child files) in the file system. @@ -413,7 +426,7 @@ class FileSystem(SimComponent): for folder_id in self.folders: self.folders[folder_id].scan(instant_scan=instant_scan) - def reveal_to_red(self, instant_scan: bool = False): + def reveal_to_red(self, instant_scan: bool = False) -> None: """ Reveals all the folders (and child files) in the file system to the red agent. @@ -422,7 +435,7 @@ class FileSystem(SimComponent): for folder_id in self.folders: self.folders[folder_id].reveal_to_red(instant_scan=instant_scan) - def restore_folder(self, folder_name: str): + def restore_folder(self, folder_name: str) -> bool: """ Restore a folder. @@ -435,13 +448,14 @@ class FileSystem(SimComponent): if folder is None: self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.") - return + return False self.deleted_folders.pop(folder.uuid, None) folder.restore() self.folders[folder.uuid] = folder + return True - def restore_file(self, folder_name: str, file_name: str): + def restore_file(self, folder_name: str, file_name: str) -> bool: """ Restore a file. @@ -454,12 +468,15 @@ class FileSystem(SimComponent): :type: file_name: str """ folder = self.get_folder(folder_name=folder_name) + if not folder: + _LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.") + return False - if folder: - file = folder.get_file(file_name=file_name, include_deleted=True) + file = folder.get_file(file_name=file_name, include_deleted=True) - if file is None: - self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.") - return + if not file: + msg = f"Unable to restore file {file_name}. File was not found." + self.sys_log.error(msg) + return False - folder.restore_file(file_name=file_name) + return folder.restore_file(file_name=file_name) diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index fbe5f4b3..efac97c3 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Dict, Optional from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.system.core.sys_log import SysLog @@ -102,12 +103,26 @@ class FileSystemItemABC(SimComponent): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan())) - rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash())) - rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair())) - rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore())) + rm.add_request( + name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())) + ) + rm.add_request( + name="checkhash", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())), + ) + rm.add_request( + name="repair", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())), + ) + rm.add_request( + name="restore", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())), + ) - rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt())) + rm.add_request( + name="corrupt", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())), + ) return rm @@ -124,9 +139,9 @@ class FileSystemItemABC(SimComponent): return convert_size(self.size) @abstractmethod - def scan(self) -> None: + def scan(self) -> bool: """Scan the folder/file - updates the visible_health_status.""" - pass + return False @abstractmethod def reveal_to_red(self) -> None: @@ -134,7 +149,7 @@ class FileSystemItemABC(SimComponent): pass @abstractmethod - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Checks the has of the file to detect any changes. @@ -142,30 +157,30 @@ class FileSystemItemABC(SimComponent): Return False if corruption is detected, otherwise True """ - pass + return False @abstractmethod - def repair(self) -> None: + def repair(self) -> bool: """ Repair the FileSystemItem. True if successfully repaired. False otherwise. """ - pass + return False @abstractmethod - def corrupt(self) -> None: + def corrupt(self) -> bool: """ Corrupt the FileSystemItem. True if successfully corrupted. False otherwise. """ - pass + return False @abstractmethod - def restore(self) -> None: + def restore(self) -> bool: """Restore the file/folder to the state before it got ruined.""" - pass + return False @abstractmethod def delete(self) -> None: diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 771dc7a0..9ef1ae59 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -5,6 +5,7 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus @@ -53,7 +54,9 @@ class Folder(FileSystemItemABC): rm = super()._init_request_manager() rm.add_request( name="delete", - request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0])) + ), ) self._file_request_manager = RequestManager() rm.add_request( @@ -249,6 +252,21 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_uuid) self.remove_file(file=file) + def remove_file_by_name(self, file_name: str) -> bool: + """ + Remove a file using its name. + + :param file_name: filename + :type file_name: str + :return: Whether it was successfully removed. + :rtype: bool + """ + for f in self.files.values(): + if f.name == file_name: + self.remove_file(f) + return True + return False + def remove_all_files(self): """Removes all the files in the folder.""" for file_id in self.files: @@ -258,7 +276,7 @@ class Folder(FileSystemItemABC): self.files = {} - def restore_file(self, file_name: str): + def restore_file(self, file_name: str) -> bool: """ Restores a file. @@ -268,13 +286,14 @@ class Folder(FileSystemItemABC): file = self.get_file(file_name=file_name, include_deleted=True) if not file: self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.") - return + return False file.restore() self.files[file.uuid] = file if file.deleted: self.deleted_files.pop(file.uuid) + return True def quarantine(self): """Quarantines the File System Folder.""" diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 82fae164..3349bed4 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -12,6 +12,7 @@ from pydantic import BaseModel, Field from primaite import getLogger from primaite.exceptions import NetworkError +from primaite.interface.request import RequestResponse from primaite.simulator import SIM_OUTPUT from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.domain.account import Account @@ -115,8 +116,8 @@ class NetworkInterface(SimComponent, ABC): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) - rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) + rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) return rm @@ -140,14 +141,16 @@ class NetworkInterface(SimComponent, ABC): return state @abstractmethod - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return False @abstractmethod - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return False def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None: """ @@ -783,16 +786,28 @@ class Node(SimComponent): self._application_request_manager = RequestManager() rm.add_request("application", RequestType(func=self._application_request_manager)) - rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red())) + rm.add_request( + "scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red())) + ) - rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off())) - rm.add_request("startup", RequestType(func=lambda request, context: self.power_on())) - rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset - rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request - rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request + rm.add_request( + "shutdown", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_off())) + ) + rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on()))) + rm.add_request( + "reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset())) + ) # TODO implement node reset + rm.add_request( + "logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False)) + ) # TODO implement logon request + rm.add_request( + "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False)) + ) # TODO implement logoff request self._os_request_manager = RequestManager() - self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan())) + self._os_request_manager.add_request( + "scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())) + ) rm.add_request("os", RequestType(func=self._os_request_manager)) return rm @@ -973,7 +988,7 @@ class Node(SimComponent): self.file_system.apply_timestep(timestep=timestep) - def scan(self) -> None: + def scan(self) -> bool: """ Scan the node and all the items within it. @@ -987,8 +1002,9 @@ class Node(SimComponent): to the red agent. """ self.node_scan_countdown = self.node_scan_duration + return True - def reveal_to_red(self) -> None: + def reveal_to_red(self) -> bool: """ Reveals the node and all the items within it to the red agent. @@ -1002,34 +1018,40 @@ class Node(SimComponent): `revealed_to_red` to `True`. """ self.red_scan_countdown = self.node_scan_duration + return True - def power_on(self): + def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.operating_state == NodeOperatingState.OFF: - self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration - if self.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") for network_interface in self.network_interfaces.values(): network_interface.enable() + return True + if self.operating_state == NodeOperatingState.OFF: + self.operating_state = NodeOperatingState.BOOTING + self.start_up_countdown = self.start_up_duration + return True - def power_off(self): + return False + + def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" + if self.shut_down_duration <= 0: + self._shut_down_actions() + self.operating_state = NodeOperatingState.OFF + self.sys_log.info("Power off") + return True if self.operating_state == NodeOperatingState.ON: for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN self.shut_down_countdown = self.shut_down_duration + return True + return False - if self.shut_down_duration <= 0: - self._shut_down_actions() - self.operating_state = NodeOperatingState.OFF - self.sys_log.info("Power off") - - def reset(self): + def reset(self) -> bool: """ Resets the node. @@ -1040,6 +1062,8 @@ class Node(SimComponent): self.is_resetting = True self.sys_log.info("Resetting") self.power_off() + return True + return False def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None): """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 7f7190fd..2fab4a3d 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable from pydantic import validate_call +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import IPWiredNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -308,19 +309,24 @@ class AccessControlList(SimComponent): rm.add_request( "add_rule", RequestType( - func=lambda request, context: self.add_rule( - action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], - src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), - src_port=None if request[3] == "ALL" else Port[request[3]], - dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]), - dst_port=None if request[5] == "ALL" else Port[request[5]], - position=int(request[6]), + func=lambda request, context: RequestResponse.from_bool( + self.add_rule( + action=ACLAction[request[0]], + protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), + src_port=None if request[3] == "ALL" else Port[request[3]], + dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]), + dst_port=None if request[5] == "ALL" else Port[request[5]], + position=int(request[6]), + ) ) ), ) - rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0])))) + rm.add_request( + "remove_rule", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))), + ) return rm def describe_state(self) -> Dict: @@ -366,7 +372,7 @@ class AccessControlList(SimComponent): src_port: Optional[Port] = None, dst_port: Optional[Port] = None, position: int = 0, - ) -> None: + ) -> bool: """ Adds a new ACL rule to control network traffic based on specified criteria. @@ -423,10 +429,12 @@ class AccessControlList(SimComponent): src_port=src_port, dst_port=dst_port, ) + return True else: raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.") + return False - def remove_rule(self, position: int) -> None: + def remove_rule(self, position: int) -> bool: """ Remove an ACL rule from a specific position. @@ -437,8 +445,10 @@ class AccessControlList(SimComponent): rule = self._acl[position] # noqa self._acl[position] = None del rule + return True else: raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.") + return False def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]: """Check if a packet with the given properties is permitted through the ACL.""" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 7b259ff4..12148683 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional from uuid import uuid4 from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -37,7 +38,7 @@ class DatabaseClient(Application): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request("execute", RequestType(func=lambda request, context: self.execute())) + rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute()))) return rm def execute(self) -> bool: diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index ee98ea8e..f71b1465 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -4,6 +4,7 @@ from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -76,7 +77,10 @@ class DataManipulationBot(Application): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack())) + rm.add_request( + name="execute", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())), + ) return rm @@ -179,21 +183,21 @@ class DataManipulationBot(Application): """ super().run() - def attack(self): + def attack(self) -> bool: """Perform the attack steps after opening the application.""" if not self._can_perform_action(): _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") self.run() - self._application_loop() + return self._application_loop() - def _application_loop(self): + def _application_loop(self) -> bool: """ The main application loop of the bot, handling the attack process. This is the core loop where the bot sequentially goes through the stages of the attack. """ if not self._can_perform_action(): - return + return False if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") self._logon() @@ -205,8 +209,12 @@ class DataManipulationBot(Application): DataManipulationAttackStage.FAILED, ): self.attack_stage = DataManipulationAttackStage.NOT_STARTED + + return True + else: self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") + return False def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index 202fd189..05f87f03 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -4,6 +4,7 @@ from typing import Optional from primaite import getLogger from primaite.game.science import simulate_trial +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.database_client import DatabaseClient @@ -59,7 +60,10 @@ class DoSBot(DatabaseClient): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + rm.add_request( + name="execute", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())), + ) return rm @@ -97,26 +101,26 @@ class DoSBot(DatabaseClient): f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}." ) - def run(self): + def run(self) -> bool: """Run the Denial of Service Bot.""" super().run() - self._application_loop() + return self._application_loop() - def _application_loop(self): + def _application_loop(self) -> bool: """ The main application loop for the Denial of Service bot. The loop goes through the stages of a DoS attack. """ if not self._can_perform_action(): - return + return False # DoS bot cannot do anything without a target if not self.target_ip_address or not self.target_port: self.sys_log.error( f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}" ) - return + return True self.clear_connections() self._perform_port_scan(p_of_success=self.port_scan_p_of_success) @@ -126,6 +130,7 @@ class DoSBot(DatabaseClient): self.attack_stage = DoSAttackStage.NOT_STARTED else: self.attack_stage = DoSAttackStage.COMPLETED + return True def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 9fa86328..5dee1dd5 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -6,6 +6,7 @@ from urllib.parse import urlparse from pydantic import BaseModel, ConfigDict from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -52,7 +53,10 @@ class WebBrowser(Application): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() rm.add_request( - name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa + name="execute", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.get_webpage()) + ), # noqa ) return rm diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 4102657c..706f166b 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Any, Dict, Optional from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -80,14 +81,14 @@ class Service(IOSoftware): def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() - rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) - rm.add_request("stop", RequestType(func=lambda request, context: self.stop())) - rm.add_request("start", RequestType(func=lambda request, context: self.start())) - rm.add_request("pause", RequestType(func=lambda request, context: self.pause())) - rm.add_request("resume", RequestType(func=lambda request, context: self.resume())) - rm.add_request("restart", RequestType(func=lambda request, context: self.restart())) - rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) - rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) + rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) + rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop()))) + rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start()))) + rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause()))) + rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume()))) + rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart()))) + rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) + rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) return rm @abstractmethod @@ -106,17 +107,19 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state - def stop(self) -> None: + def stop(self) -> bool: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED + return True + return False - def start(self, **kwargs) -> None: + def start(self, **kwargs) -> bool: """Start the service.""" # cant start the service if the node it is on is off if not super()._can_perform_action(): - return + return False if self.operating_state == ServiceOperatingState.STOPPED: self.sys_log.info(f"Starting service {self.name}") @@ -124,36 +127,47 @@ class Service(IOSoftware): # set software health state to GOOD if initially set to UNUSED if self.health_state_actual == SoftwareHealthState.UNUSED: self.set_health_state(SoftwareHealthState.GOOD) + return True + return False - def pause(self) -> None: + def pause(self) -> bool: """Pause the service.""" if self.operating_state == ServiceOperatingState.RUNNING: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED + return True + return False - def resume(self) -> None: + def resume(self) -> bool: """Resume paused service.""" if self.operating_state == ServiceOperatingState.PAUSED: self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING + return True + return False - def restart(self) -> None: + def restart(self) -> bool: """Restart running service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING self.restart_countdown = self.restart_duration + return True + return False - def disable(self) -> None: + def disable(self) -> bool: """Disable the service.""" self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED + return True - def enable(self) -> None: + def enable(self) -> bool: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED + return True + return False def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 8864659c..2af53886 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -5,6 +5,7 @@ from enum import Enum from ipaddress import IPv4Address, IPv4Network from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from primaite.interface.request import RequestResponse from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -105,16 +106,18 @@ class Software(SimComponent): rm.add_request( "compromise", RequestType( - func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED), + func=lambda request, context: RequestResponse.from_bool( + self.set_health_state(SoftwareHealthState.COMPROMISED) + ), ), ) rm.add_request( "patch", RequestType( - func=lambda request, context: self.patch(), + func=lambda request, context: RequestResponse.from_bool(self.patch()), ), ) - rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) + rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) return rm def _get_session_details(self, session_id: str) -> Session: @@ -148,7 +151,7 @@ class Software(SimComponent): ) return state - def set_health_state(self, health_state: SoftwareHealthState) -> None: + def set_health_state(self, health_state: SoftwareHealthState) -> bool: """ Assign a new health state to this software. @@ -160,6 +163,7 @@ class Software(SimComponent): :type health_state: SoftwareHealthState """ self.health_state_actual = health_state + return True def install(self) -> None: """ @@ -180,15 +184,18 @@ class Software(SimComponent): """ pass - def scan(self) -> None: + def scan(self) -> bool: """Update the observed health status to match the actual health status.""" self.health_state_visible = self.health_state_actual + return True - def patch(self) -> None: + def patch(self) -> bool: """Perform a patch on the software.""" if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD): self._patching_countdown = self.patching_duration self.set_health_state(SoftwareHealthState.PATCHING) + return True + return False def _update_patch_status(self) -> None: """Update the patch status of the software."""