Make all requests return a RequestResponse
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from typing import Dict, Literal
|
||||
from typing import Dict, ForwardRef, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, validate_call
|
||||
|
||||
RequestResponse = ForwardRef("RequestResponse")
|
||||
"""This makes it possible to type-hint RequestResponse.from_bool return type."""
|
||||
|
||||
|
||||
class RequestResponse(BaseModel):
|
||||
@@ -9,21 +12,33 @@ class RequestResponse(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
"""Cannot have extra fields in the response. Anything custom goes into the data field."""
|
||||
|
||||
status: Literal["pending", "success", "failure"] = "pending"
|
||||
status: Literal["pending", "success", "failure", "unreachable"] = "pending"
|
||||
"""
|
||||
What is the current status of the request:
|
||||
- pending - the request has not been received yet, or it has been received but it's still being processed.
|
||||
- success - the request has successfully been received and processed.
|
||||
- failure - the request could not reach it's intended target or it was rejected.
|
||||
|
||||
Note that the failure status should only be used when the request cannot be processed, for instance when the
|
||||
target SimComponent doesn't exist, or is in an OFF state that prevents it from accepting requests. If the
|
||||
request is received by the target and the associated action is executed, but couldn't be completed due to
|
||||
downstream factors, the request was still successfully received, it's just that the result wasn't what was
|
||||
intended.
|
||||
- success - the request has been received and executed successfully.
|
||||
- failure - the request has been received and attempted, but execution failed.
|
||||
- unreachable - the request could not reach it's intended target, either because it doesn't exist or the target
|
||||
is off.
|
||||
"""
|
||||
|
||||
data: Dict = {}
|
||||
"""Catch-all place to provide any additional data that was generated as a response to the request."""
|
||||
# TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too
|
||||
# much. However, in the future we might consider making them mandatory.
|
||||
|
||||
@classmethod
|
||||
@validate_call
|
||||
def from_bool(cls, status_bool: bool) -> RequestResponse:
|
||||
"""
|
||||
Construct a basic request response from a boolean.
|
||||
|
||||
True maps to a success status. False maps to a failure status.
|
||||
|
||||
:param status_bool: Whether to create a successful response
|
||||
:type status_bool: bool
|
||||
"""
|
||||
if status_bool is True:
|
||||
return cls(status="success", data={})
|
||||
elif status_bool is False:
|
||||
return cls(status="failure", data={})
|
||||
|
||||
@@ -96,7 +96,7 @@ class RequestManager(BaseModel):
|
||||
# _LOGGER.error(msg)
|
||||
# raise RuntimeError(msg)
|
||||
_LOGGER.debug(msg)
|
||||
return RequestResponse(status="failure", data={"reason": msg})
|
||||
return RequestResponse(status="unreachable", data={"reason": msg})
|
||||
|
||||
request_type = self.request_types[request_key]
|
||||
|
||||
@@ -226,7 +226,7 @@ class SimComponent(BaseModel):
|
||||
"""
|
||||
if self._request_manager is None:
|
||||
return
|
||||
self._request_manager(request, context)
|
||||
return self._request_manager(request, context)
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
@@ -41,12 +42,16 @@ class FileSystem(SimComponent):
|
||||
self._delete_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.delete_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._delete_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
@@ -57,12 +62,16 @@ class FileSystem(SimComponent):
|
||||
self._restore_manager.add_request(
|
||||
name="file",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
func=lambda request, context: RequestResponse(
|
||||
self.restore_file(folder_name=request[0], file_name=request[1])
|
||||
)
|
||||
),
|
||||
)
|
||||
self._restore_manager.add_request(
|
||||
name="folder",
|
||||
request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
@@ -138,7 +147,7 @@ class FileSystem(SimComponent):
|
||||
)
|
||||
return folder
|
||||
|
||||
def delete_folder(self, folder_name: str):
|
||||
def delete_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Deletes a folder, removes it from the folders list and removes any child folders and files.
|
||||
|
||||
@@ -146,24 +155,26 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
return
|
||||
return False
|
||||
folder = self.get_folder(folder_name)
|
||||
if folder:
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
else:
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
return False
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str):
|
||||
# set folder to deleted state
|
||||
folder.delete()
|
||||
|
||||
# remove from folder list
|
||||
self.folders.pop(folder.uuid)
|
||||
|
||||
# add to deleted list
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
return True
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a folder via its uuid.
|
||||
|
||||
@@ -297,7 +308,7 @@ class FileSystem(SimComponent):
|
||||
|
||||
return file
|
||||
|
||||
def delete_file(self, folder_name: str, file_name: str):
|
||||
def delete_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Delete a file by its name from a specific folder.
|
||||
|
||||
@@ -309,8 +320,10 @@ class FileSystem(SimComponent):
|
||||
file = folder.get_file(file_name)
|
||||
if file:
|
||||
folder.remove_file(file)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str):
|
||||
def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None:
|
||||
"""
|
||||
Deletes a file via its uuid.
|
||||
|
||||
@@ -327,7 +340,7 @@ class FileSystem(SimComponent):
|
||||
else:
|
||||
self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})")
|
||||
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
|
||||
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None:
|
||||
"""
|
||||
Move a file from one folder to another.
|
||||
|
||||
@@ -404,7 +417,7 @@ class FileSystem(SimComponent):
|
||||
# Agent actions
|
||||
###############################################################
|
||||
|
||||
def scan(self, instant_scan: bool = False):
|
||||
def scan(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Scan all the folders (and child files) in the file system.
|
||||
|
||||
@@ -413,7 +426,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].scan(instant_scan=instant_scan)
|
||||
|
||||
def reveal_to_red(self, instant_scan: bool = False):
|
||||
def reveal_to_red(self, instant_scan: bool = False) -> None:
|
||||
"""
|
||||
Reveals all the folders (and child files) in the file system to the red agent.
|
||||
|
||||
@@ -422,7 +435,7 @@ class FileSystem(SimComponent):
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].reveal_to_red(instant_scan=instant_scan)
|
||||
|
||||
def restore_folder(self, folder_name: str):
|
||||
def restore_folder(self, folder_name: str) -> bool:
|
||||
"""
|
||||
Restore a folder.
|
||||
|
||||
@@ -435,13 +448,14 @@ class FileSystem(SimComponent):
|
||||
|
||||
if folder is None:
|
||||
self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.")
|
||||
return
|
||||
return False
|
||||
|
||||
self.deleted_folders.pop(folder.uuid, None)
|
||||
folder.restore()
|
||||
self.folders[folder.uuid] = folder
|
||||
return True
|
||||
|
||||
def restore_file(self, folder_name: str, file_name: str):
|
||||
def restore_file(self, folder_name: str, file_name: str) -> bool:
|
||||
"""
|
||||
Restore a file.
|
||||
|
||||
@@ -454,12 +468,15 @@ class FileSystem(SimComponent):
|
||||
:type: file_name: str
|
||||
"""
|
||||
folder = self.get_folder(folder_name=folder_name)
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
|
||||
return False
|
||||
|
||||
if folder:
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
|
||||
if file is None:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
if not file:
|
||||
msg = f"Unable to restore file {file_name}. File was not found."
|
||||
self.sys_log.error(msg)
|
||||
return False
|
||||
|
||||
folder.restore_file(file_name=file_name)
|
||||
return folder.restore_file(file_name=file_name)
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
@@ -102,12 +103,26 @@ class FileSystemItemABC(SimComponent):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash()))
|
||||
rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair()))
|
||||
rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore()))
|
||||
rm.add_request(
|
||||
name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
|
||||
)
|
||||
rm.add_request(
|
||||
name="checkhash",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="repair",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())),
|
||||
)
|
||||
rm.add_request(
|
||||
name="restore",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())),
|
||||
)
|
||||
|
||||
rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt()))
|
||||
rm.add_request(
|
||||
name="corrupt",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -124,9 +139,9 @@ class FileSystemItemABC(SimComponent):
|
||||
return convert_size(self.size)
|
||||
|
||||
@abstractmethod
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Scan the folder/file - updates the visible_health_status."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def reveal_to_red(self) -> None:
|
||||
@@ -134,7 +149,7 @@ class FileSystemItemABC(SimComponent):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_hash(self) -> None:
|
||||
def check_hash(self) -> bool:
|
||||
"""
|
||||
Checks the has of the file to detect any changes.
|
||||
|
||||
@@ -142,30 +157,30 @@ class FileSystemItemABC(SimComponent):
|
||||
|
||||
Return False if corruption is detected, otherwise True
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def repair(self) -> None:
|
||||
def repair(self) -> bool:
|
||||
"""
|
||||
Repair the FileSystemItem.
|
||||
|
||||
True if successfully repaired. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def corrupt(self) -> None:
|
||||
def corrupt(self) -> bool:
|
||||
"""
|
||||
Corrupt the FileSystemItem.
|
||||
|
||||
True if successfully corrupted. False otherwise.
|
||||
"""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def restore(self) -> None:
|
||||
def restore(self) -> bool:
|
||||
"""Restore the file/folder to the state before it got ruined."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def delete(self) -> None:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
@@ -53,7 +54,9 @@ class Folder(FileSystemItemABC):
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="delete",
|
||||
request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])),
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0]))
|
||||
),
|
||||
)
|
||||
self._file_request_manager = RequestManager()
|
||||
rm.add_request(
|
||||
@@ -249,6 +252,21 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file_by_id(file_uuid=file_uuid)
|
||||
self.remove_file(file=file)
|
||||
|
||||
def remove_file_by_name(self, file_name: str) -> bool:
|
||||
"""
|
||||
Remove a file using its name.
|
||||
|
||||
:param file_name: filename
|
||||
:type file_name: str
|
||||
:return: Whether it was successfully removed.
|
||||
:rtype: bool
|
||||
"""
|
||||
for f in self.files.values():
|
||||
if f.name == file_name:
|
||||
self.remove_file(f)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_all_files(self):
|
||||
"""Removes all the files in the folder."""
|
||||
for file_id in self.files:
|
||||
@@ -258,7 +276,7 @@ class Folder(FileSystemItemABC):
|
||||
|
||||
self.files = {}
|
||||
|
||||
def restore_file(self, file_name: str):
|
||||
def restore_file(self, file_name: str) -> bool:
|
||||
"""
|
||||
Restores a file.
|
||||
|
||||
@@ -268,13 +286,14 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file(file_name=file_name, include_deleted=True)
|
||||
if not file:
|
||||
self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.")
|
||||
return
|
||||
return False
|
||||
|
||||
file.restore()
|
||||
self.files[file.uuid] = file
|
||||
|
||||
if file.deleted:
|
||||
self.deleted_files.pop(file.uuid)
|
||||
return True
|
||||
|
||||
def quarantine(self):
|
||||
"""Quarantines the File System Folder."""
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
@@ -115,8 +116,8 @@ class NetworkInterface(SimComponent, ABC):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
|
||||
return rm
|
||||
|
||||
@@ -140,14 +141,16 @@ class NetworkInterface(SimComponent, ABC):
|
||||
return state
|
||||
|
||||
@abstractmethod
|
||||
def enable(self):
|
||||
def enable(self) -> bool:
|
||||
"""Enable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def disable(self):
|
||||
def disable(self) -> bool:
|
||||
"""Disable the interface."""
|
||||
pass
|
||||
return False
|
||||
|
||||
def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None:
|
||||
"""
|
||||
@@ -783,16 +786,28 @@ class Node(SimComponent):
|
||||
self._application_request_manager = RequestManager()
|
||||
rm.add_request("application", RequestType(func=self._application_request_manager))
|
||||
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red()))
|
||||
rm.add_request(
|
||||
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()))
|
||||
)
|
||||
|
||||
rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
|
||||
rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset
|
||||
rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
|
||||
rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
|
||||
rm.add_request(
|
||||
"shutdown", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_off()))
|
||||
)
|
||||
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
|
||||
rm.add_request(
|
||||
"reset", RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()))
|
||||
) # TODO implement node reset
|
||||
rm.add_request(
|
||||
"logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
|
||||
) # TODO implement logon request
|
||||
rm.add_request(
|
||||
"logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False))
|
||||
) # TODO implement logoff request
|
||||
|
||||
self._os_request_manager = RequestManager()
|
||||
self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
self._os_request_manager.add_request(
|
||||
"scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))
|
||||
)
|
||||
rm.add_request("os", RequestType(func=self._os_request_manager))
|
||||
|
||||
return rm
|
||||
@@ -973,7 +988,7 @@ class Node(SimComponent):
|
||||
|
||||
self.file_system.apply_timestep(timestep=timestep)
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""
|
||||
Scan the node and all the items within it.
|
||||
|
||||
@@ -987,8 +1002,9 @@ class Node(SimComponent):
|
||||
to the red agent.
|
||||
"""
|
||||
self.node_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def reveal_to_red(self) -> None:
|
||||
def reveal_to_red(self) -> bool:
|
||||
"""
|
||||
Reveals the node and all the items within it to the red agent.
|
||||
|
||||
@@ -1002,34 +1018,40 @@ class Node(SimComponent):
|
||||
`revealed_to_red` to `True`.
|
||||
"""
|
||||
self.red_scan_countdown = self.node_scan_duration
|
||||
return True
|
||||
|
||||
def power_on(self):
|
||||
def power_on(self) -> bool:
|
||||
"""Power on the Node, enabling its NICs if it is in the OFF state."""
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
|
||||
if self.start_up_duration <= 0:
|
||||
self.operating_state = NodeOperatingState.ON
|
||||
self._start_up_actions()
|
||||
self.sys_log.info("Power on")
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.enable()
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.OFF:
|
||||
self.operating_state = NodeOperatingState.BOOTING
|
||||
self.start_up_countdown = self.start_up_duration
|
||||
return True
|
||||
|
||||
def power_off(self):
|
||||
return False
|
||||
|
||||
def power_off(self) -> bool:
|
||||
"""Power off the Node, disabling its NICs if it is in the ON state."""
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
return True
|
||||
if self.operating_state == NodeOperatingState.ON:
|
||||
for network_interface in self.network_interfaces.values():
|
||||
network_interface.disable()
|
||||
self.operating_state = NodeOperatingState.SHUTTING_DOWN
|
||||
self.shut_down_countdown = self.shut_down_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
if self.shut_down_duration <= 0:
|
||||
self._shut_down_actions()
|
||||
self.operating_state = NodeOperatingState.OFF
|
||||
self.sys_log.info("Power off")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> bool:
|
||||
"""
|
||||
Resets the node.
|
||||
|
||||
@@ -1040,6 +1062,8 @@ class Node(SimComponent):
|
||||
self.is_resetting = True
|
||||
self.sys_log.info("Resetting")
|
||||
self.power_off()
|
||||
return True
|
||||
return False
|
||||
|
||||
def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None):
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import validate_call
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -308,19 +309,24 @@ class AccessControlList(SimComponent):
|
||||
rm.add_request(
|
||||
"add_rule",
|
||||
RequestType(
|
||||
func=lambda request, context: self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.add_rule(
|
||||
action=ACLAction[request[0]],
|
||||
protocol=None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]),
|
||||
src_port=None if request[3] == "ALL" else Port[request[3]],
|
||||
dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]),
|
||||
dst_port=None if request[5] == "ALL" else Port[request[5]],
|
||||
position=int(request[6]),
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
|
||||
rm.add_request(
|
||||
"remove_rule",
|
||||
RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))),
|
||||
)
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
@@ -366,7 +372,7 @@ class AccessControlList(SimComponent):
|
||||
src_port: Optional[Port] = None,
|
||||
dst_port: Optional[Port] = None,
|
||||
position: int = 0,
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""
|
||||
Adds a new ACL rule to control network traffic based on specified criteria.
|
||||
|
||||
@@ -423,10 +429,12 @@ class AccessControlList(SimComponent):
|
||||
src_port=src_port,
|
||||
dst_port=dst_port,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def remove_rule(self, position: int) -> None:
|
||||
def remove_rule(self, position: int) -> bool:
|
||||
"""
|
||||
Remove an ACL rule from a specific position.
|
||||
|
||||
@@ -437,8 +445,10 @@ class AccessControlList(SimComponent):
|
||||
rule = self._acl[position] # noqa
|
||||
self._acl[position] = None
|
||||
del rule
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.")
|
||||
return False
|
||||
|
||||
def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]:
|
||||
"""Check if a packet with the given properties is permitted through the ACL."""
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -37,7 +38,7 @@ class DatabaseClient(Application):
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: self.execute()))
|
||||
rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute())))
|
||||
return rm
|
||||
|
||||
def execute(self) -> bool:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -76,7 +77,10 @@ class DataManipulationBot(Application):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -179,21 +183,21 @@ class DataManipulationBot(Application):
|
||||
"""
|
||||
super().run()
|
||||
|
||||
def attack(self):
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop of the bot, handling the attack process.
|
||||
|
||||
This is the core loop where the bot sequentially goes through the stages of the attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self._logon()
|
||||
@@ -205,8 +209,12 @@ class DataManipulationBot(Application):
|
||||
DataManipulationAttackStage.FAILED,
|
||||
):
|
||||
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
@@ -59,7 +60,10 @@ class DoSBot(DatabaseClient):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
|
||||
rm.add_request(
|
||||
name="execute",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -97,26 +101,26 @@ class DoSBot(DatabaseClient):
|
||||
f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}."
|
||||
)
|
||||
|
||||
def run(self):
|
||||
def run(self) -> bool:
|
||||
"""Run the Denial of Service Bot."""
|
||||
super().run()
|
||||
self._application_loop()
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self):
|
||||
def _application_loop(self) -> bool:
|
||||
"""
|
||||
The main application loop for the Denial of Service bot.
|
||||
|
||||
The loop goes through the stages of a DoS attack.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return
|
||||
return True
|
||||
|
||||
self.clear_connections()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
@@ -126,6 +130,7 @@ class DoSBot(DatabaseClient):
|
||||
self.attack_stage = DoSAttackStage.NOT_STARTED
|
||||
else:
|
||||
self.attack_stage = DoSAttackStage.COMPLETED
|
||||
return True
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from urllib.parse import urlparse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
@@ -52,7 +53,10 @@ class WebBrowser(Application):
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
|
||||
name="execute",
|
||||
request_type=RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.get_webpage())
|
||||
), # noqa
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
@@ -80,14 +81,14 @@ class Service(IOSoftware):
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: self.stop()))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: self.start()))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: self.pause()))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: self.resume()))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: self.restart()))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: self.disable()))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: self.enable()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
|
||||
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
|
||||
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
|
||||
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
|
||||
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
|
||||
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
|
||||
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
|
||||
return rm
|
||||
|
||||
@abstractmethod
|
||||
@@ -106,17 +107,19 @@ class Service(IOSoftware):
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self) -> bool:
|
||||
"""Stop the service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Stopping service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def start(self, **kwargs) -> None:
|
||||
def start(self, **kwargs) -> bool:
|
||||
"""Start the service."""
|
||||
# cant start the service if the node it is on is off
|
||||
if not super()._can_perform_action():
|
||||
return
|
||||
return False
|
||||
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
self.sys_log.info(f"Starting service {self.name}")
|
||||
@@ -124,36 +127,47 @@ class Service(IOSoftware):
|
||||
# set software health state to GOOD if initially set to UNUSED
|
||||
if self.health_state_actual == SoftwareHealthState.UNUSED:
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
return True
|
||||
return False
|
||||
|
||||
def pause(self) -> None:
|
||||
def pause(self) -> bool:
|
||||
"""Pause the service."""
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.PAUSED
|
||||
return True
|
||||
return False
|
||||
|
||||
def resume(self) -> None:
|
||||
def resume(self) -> bool:
|
||||
"""Resume paused service."""
|
||||
if self.operating_state == ServiceOperatingState.PAUSED:
|
||||
self.sys_log.info(f"Resuming service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
return True
|
||||
return False
|
||||
|
||||
def restart(self) -> None:
|
||||
def restart(self) -> bool:
|
||||
"""Restart running service."""
|
||||
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
|
||||
self.sys_log.info(f"Pausing service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RESTARTING
|
||||
self.restart_countdown = self.restart_duration
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self) -> None:
|
||||
def disable(self) -> bool:
|
||||
"""Disable the service."""
|
||||
self.sys_log.info(f"Disabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.DISABLED
|
||||
return True
|
||||
|
||||
def enable(self) -> None:
|
||||
def enable(self) -> bool:
|
||||
"""Enable the disabled service."""
|
||||
if self.operating_state == ServiceOperatingState.DISABLED:
|
||||
self.sys_log.info(f"Enabling Application {self.name}")
|
||||
self.operating_state = ServiceOperatingState.STOPPED
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
@@ -105,16 +106,18 @@ class Software(SimComponent):
|
||||
rm.add_request(
|
||||
"compromise",
|
||||
RequestType(
|
||||
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
),
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"patch",
|
||||
RequestType(
|
||||
func=lambda request, context: self.patch(),
|
||||
func=lambda request, context: RequestResponse.from_bool(self.patch()),
|
||||
),
|
||||
)
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
|
||||
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
|
||||
return rm
|
||||
|
||||
def _get_session_details(self, session_id: str) -> Session:
|
||||
@@ -148,7 +151,7 @@ class Software(SimComponent):
|
||||
)
|
||||
return state
|
||||
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> None:
|
||||
def set_health_state(self, health_state: SoftwareHealthState) -> bool:
|
||||
"""
|
||||
Assign a new health state to this software.
|
||||
|
||||
@@ -160,6 +163,7 @@ class Software(SimComponent):
|
||||
:type health_state: SoftwareHealthState
|
||||
"""
|
||||
self.health_state_actual = health_state
|
||||
return True
|
||||
|
||||
def install(self) -> None:
|
||||
"""
|
||||
@@ -180,15 +184,18 @@ class Software(SimComponent):
|
||||
"""
|
||||
pass
|
||||
|
||||
def scan(self) -> None:
|
||||
def scan(self) -> bool:
|
||||
"""Update the observed health status to match the actual health status."""
|
||||
self.health_state_visible = self.health_state_actual
|
||||
return True
|
||||
|
||||
def patch(self) -> None:
|
||||
def patch(self) -> bool:
|
||||
"""Perform a patch on the software."""
|
||||
if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD):
|
||||
self._patching_countdown = self.patching_duration
|
||||
self.set_health_state(SoftwareHealthState.PATCHING)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _update_patch_status(self) -> None:
|
||||
"""Update the patch status of the software."""
|
||||
|
||||
Reference in New Issue
Block a user