Merge branch 'feature/1812-traverse-actions-dict' into feature/1947-implement-missing-node-actions

This commit is contained in:
Czar.Echavez
2023-10-10 08:58:58 +01:00
58 changed files with 2111 additions and 950 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -11,9 +11,9 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
class ActionPermissionValidator(BaseModel):
class RequestPermissionValidator(BaseModel):
"""
Base class for action validators.
Base class for request validators.
The permissions manager is designed to be generic. So, although in the first instance the permissions
are evaluated purely on membership to AccountGroup, this class can support validating permissions based on any
@@ -22,130 +22,127 @@ class ActionPermissionValidator(BaseModel):
@abstractmethod
def __call__(self, request: List[str], context: Dict) -> bool:
"""Use the request and context paramters to decide whether the action should be permitted."""
"""Use the request and context paramters to decide whether the request should be permitted."""
pass
class AllowAllValidator(ActionPermissionValidator):
"""Always allows the action."""
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
def __call__(self, request: List[str], context: Dict) -> bool:
"""Always allow the action."""
"""Always allow the request."""
return True
class Action(BaseModel):
class RequestType(BaseModel):
"""
This object stores data related to a single action.
This object stores data related to a single request type.
This includes the callable that can execute the action request, and the validator that will decide whether
the action can be performed or not.
This includes the callable that can execute the request, and the validator that will decide whether
the request can be performed or not.
"""
func: Callable[[List[str], Dict], None]
"""
``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
that invokes a class method of your SimComponent. For example if the component is a node and the request type is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
Then, this request will be given something like ``func = lambda request, context: self.turn_off()``.
``func`` can also be another action manager, since ActionManager is a callable with a signature that matches what is
``func`` can also be another request manager, since RequestManager is a callable with a signature that matches what is
expected by ``func``.
"""
validator: ActionPermissionValidator = AllowAllValidator()
validator: RequestPermissionValidator = AllowAllValidator()
"""
``validator`` is an instance of `ActionPermissionValidator`. This is essentially a callable that
``validator`` is an instance of ``RequestPermissionValidator``. This is essentially a callable that
accepts `request` and `context` and returns a boolean to represent whether the permission is granted to perform
the action. The default validator will allow
the request. The default validator will allow
"""
# TODO: maybe this can be renamed to something like action selector?
# Because there are two ways it's used, to select from a list of action verbs, or to select a child object to which to
# forward the request.
class ActionManager(BaseModel):
class RequestManager(BaseModel):
"""
ActionManager is used by `SimComponent` instances to keep track of actions.
RequestManager is used by `SimComponent` instances to keep track of requests.
Its main purpose is to be a lookup from action name to action function and corresponding validation function. This
class is responsible for providing a consistent API for processing actions as well as helpful error messages.
Its main purpose is to be a lookup from request name to request function and corresponding validation function. This
class is responsible for providing a consistent API for processing requests as well as helpful error messages.
"""
actions: Dict[str, Action] = {}
"""maps action verb to an action object."""
request_types: Dict[str, RequestType] = {}
"""maps request name to an RequestType object."""
def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None:
"""
Process an action request.
Process an request request.
:param request: A list of strings which specify what action to take. The first string must be one of the allowed
actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters
to the action function.
:param request: A list of strings describing the request. The first string must be one of the allowed
request names, i.e. it must be a key of self.request_types. The subsequent strings in the list are passed as
parameters to the request function.
:type request: List[str]
:param context: Dictionary of additional information necessary to process or validate the request.
:type context: Dict
:raises RuntimeError: If the request parameter does not have a valid action identifier as the first item.
:raises RuntimeError: If the request parameter does not have a valid request name as the first item.
"""
action_key = request[0]
request_key = request[0]
if action_key not in self.actions:
if request_key not in self.request_types:
msg = (
f"Action request {request} could not be processed because {action_key} is not a valid action",
"within this ActionManager",
f"Request {request} could not be processed because {request_key} is not a valid request name",
"within this RequestManager",
)
_LOGGER.error(msg)
raise RuntimeError(msg)
action = self.actions[action_key]
action_options = request[1:]
request_type = self.request_types[request_key]
request_options = request[1:]
if not action.validator(action_options, context):
_LOGGER.debug(f"Action request {request} was denied due to insufficient permissions")
if not request_type.validator(request_options, context):
_LOGGER.debug(f"Request {request} was denied due to insufficient permissions")
return
action.func(action_options, context)
request_type.func(request_options, context)
def add_action(self, name: str, action: Action) -> None:
def add_request(self, name: str, request_type: RequestType) -> None:
"""
Add an action to this action manager.
Add a request type to this request manager.
:param name: The string associated to this action.
:param name: The string associated to this request.
:type name: str
:param action: Action object.
:type action: Action
:param request_type: Request type object which contains information about how to resolve request.
:type request_type: RequestType
"""
if name in self.actions:
msg = f"Attempted to register an action but the action name {name} is already taken."
if name in self.request_types:
msg = f"Attempted to register a request but the request name {name} is already taken."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.actions[name] = action
self.request_types[name] = request_type
def remove_action(self, name: str) -> None:
def remove_request(self, name: str) -> None:
"""
Remove an action from this manager.
Remove a request from this manager.
:param name: name identifier of the action
:param name: name identifier of the request
:type name: str
"""
if name not in self.actions:
msg = f"Attempted to remove action {name} from action manager, but it was not registered."
if name not in self.request_types:
msg = f"Attempted to remove request {name} from request manager, but it was not registered."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.actions.pop(name)
self.request_types.pop(name)
def get_action_tree(self) -> List[List[str]]:
"""Recursively generate action tree for this component."""
actions = []
for act_name, act in self.actions.items():
if isinstance(act.func, ActionManager):
sub_actions = act.func.get_action_tree()
sub_actions = [[act_name] + a for a in sub_actions]
actions.extend(sub_actions)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
sub_requests = req.func.get_request_types_recursively()
sub_requests = [[req_name] + a for a in sub_requests]
requests.extend(sub_requests)
else:
actions.append([act_name])
return actions
requests.append([req_name])
return requests
class SimComponent(BaseModel):
@@ -161,30 +158,30 @@ class SimComponent(BaseModel):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
super().__init__(**kwargs)
self._action_manager: ActionManager = self._init_action_manager()
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
def _init_action_manager(self) -> ActionManager:
def _init_request_manager(self) -> RequestManager:
"""
Initialise the action manager for this component.
Initialise the request manager for this component.
When using a hierarchy of components, the child classes should call the parent class's _init_action_manager and
add additional actions on top of the existing generic ones.
When using a hierarchy of components, the child classes should call the parent class's _init_request_manager and
add additional requests on top of the existing generic ones.
Example usage for inherited classes:
..code::python
class WebBrowser(Application):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager() # all actions generic to any Application get initialised
am.add_action(...) # initialise any actions specific to the web browser
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager() # all requests generic to any Application get initialised
am.add_request(...) # initialise any requests specific to the web browser
return am
:return: Actiona manager object belonging to this sim component.
:rtype: ActionManager
:return: Request manager object belonging to this sim component.
:rtype: RequestManager
"""
return ActionManager()
return RequestManager()
@abstractmethod
def describe_state(self) -> Dict:
@@ -204,27 +201,27 @@ class SimComponent(BaseModel):
"""Update the visible statuses of the SimComponent."""
pass
def apply_action(self, action: List[str], context: Dict = {}) -> None:
def apply_request(self, request: List[str], context: Dict = {}) -> None:
"""
Apply an action to a simulation component. Action data is passed in as a 'namespaced' list of strings.
Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings.
If the list only has one element, the action is intended to be applied directly to this object. If the list has
multiple entries, the action is passed to the child of this object specified by the first one or two entries.
If the list only has one element, the request is intended to be applied directly to this object. If the list has
multiple entries, the request is passed to the child of this object specified by the first one or two entries.
This is essentially a namespace.
For example, ["turn_on",] is meant to apply an action of 'turn on' to this component.
For example, ["turn_on",] is meant to apply a request of 'turn on' to this component.
However, ["services", "email_client", "turn_on"] is meant to 'turn on' this component's email client service.
:param action: List describing the action to apply to this object.
:type action: List[str]
:param request: List describing the request to apply to this object.
:type request: List[str]
:param: context: Dict containing context for actions
:param: context: Dict containing context for requests
:type context: Dict
"""
if self._action_manager is None:
if self._request_manager is None:
return
self._action_manager(action, context)
self._request_manager(request, context)
def apply_timestep(self, timestep: int) -> None:
"""

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Dict, Final, List, Literal, Tuple
from primaite.simulator.core import Action, ActionManager, ActionPermissionValidator, SimComponent
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.domain.account import Account, AccountType
@@ -43,10 +43,10 @@ class AccountGroup(Enum):
"For full access"
class GroupMembershipValidator(ActionPermissionValidator):
class GroupMembershipValidator(RequestPermissionValidator):
"""Permit actions based on group membership."""
allowed_groups:List[AccountGroup]
allowed_groups: List[AccountGroup]
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
@@ -79,14 +79,14 @@ class DomainController(SimComponent):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# Action 'account' matches requests like:
# ['account', '<account-uuid>', *account_action]
am.add_action(
am.add_request(
"account",
Action(
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
RequestType(
func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context),
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)

View File

@@ -9,7 +9,7 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from primaite.simulator.system.core.sys_log import SysLog
@@ -94,14 +94,14 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
self._folder_action_manager = ActionManager()
am.add_action("folder", Action(func=self._folder_action_manager))
self._folder_request_manager = RequestManager()
am.add_request("folder", RequestType(func=self._folder_request_manager))
self._file_action_manager = ActionManager()
am.add_action("file", Action(func=self._file_action_manager))
self._file_request_manager = RequestManager()
am.add_request("file", RequestType(func=self._file_request_manager))
return am
@@ -165,7 +165,7 @@ class FileSystem(SimComponent):
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
self._folder_action_manager.add_action(folder.uuid, Action(func=folder._action_manager))
self._folder_request_manager.add_request(folder.uuid, RequestType(func=folder._request_manager))
return folder
def delete_folder(self, folder_name: str):
@@ -184,7 +184,7 @@ class FileSystem(SimComponent):
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
self._folder_action_manager.remove_action(folder.uuid)
self._folder_request_manager.remove_request(folder.uuid)
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
@@ -226,7 +226,7 @@ class FileSystem(SimComponent):
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
self._file_action_manager.add_action(file.uuid, Action(func=file._action_manager))
self._file_request_manager.add_request(file.uuid, RequestType(func=file._request_manager))
return file
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
@@ -240,7 +240,7 @@ class FileSystem(SimComponent):
folder = self.get_folder(folder_name)
if folder:
return folder.get_file(file_name)
self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}")
self.sys_log.info(f"file not found /{folder_name}/{file_name}")
def delete_file(self, folder_name: str, file_name: str):
"""
@@ -254,7 +254,7 @@ class FileSystem(SimComponent):
file = folder.get_file(file_name)
if file:
folder.remove_file(file)
self._file_action_manager.remove_action(file.uuid)
self._file_request_manager.remove_request(file.uuid)
self.sys_log.info(f"Deleted file /{file.path}")
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
@@ -332,15 +332,15 @@ class Folder(FileSystemItemABC):
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
def _init_action_manager(sekf) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am
@@ -509,15 +509,15 @@ class File(FileSystemItemABC):
with open(self.sim_path, mode="a"):
pass
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("checkhash", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("delete", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("repair", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("restore", Action(func=lambda request, context: ...)) # TODO implement action
am.add_action("corrupt", Action(func=lambda request, context: ...)) # TODO implement action
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("checkhash", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("delete", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("repair", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("restore", RequestType(func=lambda request, context: ...)) # TODO implement request
am.add_request("corrupt", RequestType(func=lambda request, context: ...)) # TODO implement request
return am

View File

@@ -6,7 +6,7 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
@@ -37,21 +37,18 @@ class Network(SimComponent):
Initialise the network.
Constructs the network and sets up its initial state including
the action manager and an empty MultiGraph for topology representation.
the request manager and an empty MultiGraph for topology representation.
"""
super().__init__(**kwargs)
self._nx_graph = MultiGraph()
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
self._node_action_manager = ActionManager()
am.add_action(
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
self._node_request_manager = RequestManager()
am.add_request(
"node",
Action(
func=self._node_action_manager
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
),
RequestType(func=self._node_request_manager),
)
return am
@@ -185,7 +182,7 @@ class Network(SimComponent):
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager))
self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -219,7 +216,7 @@ class Network(SimComponent):
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_action_manager.remove_action(name=node.uuid)
self._node_request_manager.remove_request(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""

View File

@@ -12,7 +12,7 @@ from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
@@ -144,11 +144,11 @@ class NIC(SimComponent):
)
return state
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_action("enable", Action(func=lambda request, context: self.enable()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
return am
@@ -502,7 +502,9 @@ class Link(SimComponent):
def _can_transmit(self, frame: Frame) -> bool:
if self.is_up:
frame_size_Mbits = frame.size_Mbits # noqa - Leaving it as Mbits as this is how they're expressed
return self.current_load + frame_size_Mbits <= self.bandwidth
# return self.current_load + frame_size_Mbits <= self.bandwidth
# TODO: re add this check once packet size limiting and MTU checks are implemented
return True
return False
def transmit_frame(self, sender_nic: Union[NIC, SwitchPort], frame: Frame) -> bool:
@@ -720,7 +722,9 @@ class ARPCache:
# Unmatched ARP Request
if arp_packet.target_ip_address != from_nic.ip_address:
self.sys_log.info(f"Ignoring ARP request for {arp_packet.target_ip_address}")
self.sys_log.info(
f"Ignoring ARP request for {arp_packet.target_ip_address}. Current IP address is {from_nic.ip_address}"
)
return
# Matched ARP request
@@ -942,35 +946,40 @@ class Node(SimComponent):
super().__init__(**kwargs)
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
self._install_system_software()
def _init_action_manager(self) -> ActionManager:
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
# need a better name and better documentation.
am = super()._init_action_manager()
# since there are potentially many services, create an action manager that can map service name
self._service_action_manager = ActionManager()
am.add_action("service", Action(func=self._service_action_manager))
self._nic_action_manager = ActionManager()
am.add_action("nic", Action(func=self._nic_action_manager))
am = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
self._service_request_manager = RequestManager()
am.add_request("service", RequestType(func=self._service_request_manager))
self._nic_request_manager = RequestManager()
am.add_request("nic", RequestType(func=self._nic_request_manager))
am.add_action("file_system", Action(func=self.file_system._action_manager))
am.add_request("file_system", RequestType(func=self.file_system._request_manager))
# currently we don't have any applications nor processes, so these will be empty
self._process_action_manager = ActionManager()
am.add_action("process", Action(func=self._process_action_manager))
self._application_action_manager = ActionManager()
am.add_action("application", Action(func=self._application_action_manager))
self._process_request_manager = RequestManager()
am.add_request("process", RequestType(func=self._process_request_manager))
self._application_request_manager = RequestManager()
am.add_request("application", RequestType(func=self._application_request_manager))
am.add_action("scan", Action(func=lambda request, context: ...)) # TODO implement OS scan
am.add_request("scan", RequestType(func=lambda request, context: ...)) # TODO implement OS scan
am.add_action("shutdown", Action(func=lambda request, context: self.power_off()))
am.add_action("startup", Action(func=lambda request, context: self.power_on()))
am.add_action("reset", Action(func=lambda request, context: ...)) # TODO implement node reset
am.add_action("logon", Action(func=lambda request, context: ...)) # TODO implement logon action
am.add_action("logoff", Action(func=lambda request, context: ...)) # TODO implement logoff action
am.add_request("shutdown", RequestType(func=lambda request, context: self.power_off()))
am.add_request("startup", RequestType(func=lambda request, context: self.power_on()))
am.add_request("reset", RequestType(func=lambda request, context: ...)) # TODO implement node reset
am.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request
am.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request
return am
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1064,7 +1073,7 @@ class Node(SimComponent):
self.sys_log.info(f"Connected NIC {nic}")
if self.operating_state == NodeOperatingState.ON:
nic.enable()
self._nic_action_manager.add_action(nic.uuid, Action(func=nic._action_manager))
self._nic_request_manager.add_request(nic.uuid, RequestType(func=nic._request_manager))
else:
msg = f"Cannot connect NIC {nic} as it is already connected"
self.sys_log.logger.error(msg)
@@ -1089,7 +1098,7 @@ class Node(SimComponent):
nic.parent = None
nic.disable()
self.sys_log.info(f"Disconnected NIC {nic}")
self._nic_action_manager.remove_action(nic.uuid)
self._nic_request_manager.remove_request(nic.uuid)
else:
msg = f"Cannot disconnect NIC {nic} as it is not connected"
self.sys_log.logger.error(msg)
@@ -1187,7 +1196,7 @@ class Node(SimComponent):
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
self._service_action_manager.add_action(service.uuid, Action(func=service._action_manager))
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
@@ -1203,7 +1212,7 @@ class Node(SimComponent):
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
self._service_action_manager.remove_action(service.uuid)
self._service_request_manager.remove_request(service.uuid)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):

View File

@@ -1,4 +1,7 @@
from primaite.simulator.network.hardware.base import NIC, Node
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
class Computer(Node):
@@ -36,3 +39,17 @@ class Computer(Node):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connect_nic(NIC(ip_address=kwargs["ip_address"], subnet_mask=kwargs["subnet_mask"]))
self._install_system_software()
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
# DNS Client
self.software_manager.install(DNSClient)
# FTP
self.software_manager.install(FTPClient)
# Web Browser
self.software_manager.install(WebBrowser)
super()._install_system_software()

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
@@ -43,7 +43,7 @@ class ACLRule(SimComponent):
def __str__(self) -> str:
rule_strings = []
for key, value in self.model_dump(exclude={"uuid", "action_manager"}).items():
for key, value in self.model_dump(exclude={"uuid", "request_manager"}).items():
if value is None:
value = "ANY"
if isinstance(value, Enum):
@@ -87,8 +87,8 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# When the request reaches this action, it should now contain solely positional args for the 'add_rule' action.
# POSITIONAL ARGUMENTS:
@@ -99,9 +99,9 @@ class AccessControlList(SimComponent):
# 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2'))
# 5: destination port (str name of a Port (e.g. "HTTP"))
# 6: position (int)
am.add_action(
am.add_request(
"add_rule",
Action(
RequestType(
func=lambda request, context: self.add_rule(
ACLAction[request[0]],
IPProtocol[request[1]],
@@ -114,7 +114,7 @@ class AccessControlList(SimComponent):
),
)
am.add_action("remove_rule", Action(func=lambda request, context: self.remove_rule(int(request[0]))))
am.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0]))))
return am
def describe_state(self) -> Dict:
@@ -626,9 +626,9 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("acl", Action(func=self.acl._action_manager))
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request("acl", RequestType(func=self.acl._request_manager))
return am
def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]:

View File

@@ -9,10 +9,11 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.web_server.web_server import WebServer
def client_server_routed() -> Network:
@@ -135,9 +136,6 @@ def arcd_uc2_network() -> Network:
dns_server=IPv4Address("192.168.1.10"),
)
client_1.power_on()
client_1.software_manager.install(DNSClient)
client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa
client_1_dns_client_service.start()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
@@ -152,9 +150,6 @@ def arcd_uc2_network() -> Network:
dns_server=IPv4Address("192.168.1.10"),
)
client_2.power_on()
client_2.software_manager.install(DNSClient)
client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa
client_2_dns_client_service.start()
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
@@ -191,24 +186,53 @@ def arcd_uc2_network() -> Network:
);"""
user_insert_statements = [
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');",
# noqa
"INSERT INTO user (name, email, age, city, occupation) "
"VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');",
# noqa
]
database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None) # noqa
@@ -230,9 +254,10 @@ def arcd_uc2_network() -> Network:
database_client.run()
database_client.connect()
web_server.software_manager.install(WebServer)
# register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
dns_server_service.start()
dns_server_service.dns_register("arcd.com", web_server.ip_address)
# Backup Server
@@ -244,6 +269,7 @@ def arcd_uc2_network() -> Network:
dns_server=IPv4Address("192.168.1.10"),
)
backup_server.power_on()
backup_server.software_manager.install(FTPServer)
network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])
# Security Suite
@@ -271,4 +297,10 @@ def arcd_uc2_network() -> Network:
# Allow DNS requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
# Allow FTP requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2)
# Open port 80 for web server
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
return network

View File

@@ -5,6 +5,8 @@ from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
class ARPEntry(BaseModel):
"""
@@ -18,7 +20,7 @@ class ARPEntry(BaseModel):
nic_uuid: str
class ARPPacket(BaseModel):
class ARPPacket(DataPacket):
"""
Represents the ARP layer of a network frame.

View File

@@ -5,6 +5,8 @@ from typing import Optional
from pydantic import BaseModel
from primaite.simulator.network.protocols.packet import DataPacket
class DNSRequest(BaseModel):
"""Represents a DNS Request packet of a network frame.
@@ -26,7 +28,7 @@ class DNSReply(BaseModel):
"IP Address of the Domain Name requested."
class DNSPacket(BaseModel):
class DNSPacket(DataPacket):
"""
Represents the DNS layer of a network frame.

View File

@@ -0,0 +1,55 @@
from enum import Enum
from typing import Any, Optional, Union
from primaite.simulator.network.protocols.packet import DataPacket
class FTPCommand(Enum):
"""FTP Commands that are allowed."""
PORT = "PORT"
"""Set a port to be used for the FTP transfer."""
STOR = "STOR"
"""Copy or put data to the FTP server."""
RETR = "RETR"
"""Retrieve data from the FTP server."""
DELE = "DELE"
"""Delete the file in the specified path."""
RMD = "RMD"
"""Remove the directory in the specified path."""
MKD = "MKD"
"""Make a directory in the specified path."""
LIST = "LIST"
"""Return a list of files in the specified path."""
QUIT = "QUIT"
"""Ends connection between client and server."""
class FTPStatusCode(Enum):
"""Status code of the current FTP request."""
OK = 200
"""Command successful."""
ERROR = 500
"""General error code."""
class FTPPacket(DataPacket):
"""Represents an FTP Packet."""
ftp_command: FTPCommand
"""Command type of the packet."""
ftp_command_args: Optional[Any] = None
"""Arguments for command."""
status_code: Union[FTPStatusCode, None] = None
"""Status of the response."""

View File

@@ -0,0 +1,64 @@
from enum import Enum
from primaite.simulator.network.protocols.packet import DataPacket
class HttpRequestMethod(Enum):
"""Enum list of HTTP Request methods that can be handled by the simulation."""
GET = "GET"
"""HTTP GET Method. Requests using GET should only retrieve data."""
HEAD = "HEAD"
"""Asks for a response identical to a GET request, but without the response body."""
POST = "POST"
"""Submit an entity to the specified resource, often causing a change in state or side effects on the server."""
PUT = "PUT"
"""Replace all current representations of the target resource with the request payload."""
DELETE = "DELETE"
"""Delete the specified resource."""
PATCH = "PATCH"
"""Apply partial modifications to a resource."""
class HttpStatusCode(Enum):
"""List of available HTTP Statuses."""
OK = 200
"""request has succeeded."""
BAD_REQUEST = 400
"""Payload cannot be parsed."""
UNAUTHORIZED = 401
"""Auth required."""
NOT_FOUND = 404
"""Item not found in server."""
METHOD_NOT_ALLOWED = 405
"""Method is not supported by server."""
INTERNAL_SERVER_ERROR = 500
"""Error on the server side."""
class HttpRequestPacket(DataPacket):
"""Class that represents an HTTP Request Packet."""
request_method: HttpRequestMethod
"""The HTTP Request method."""
request_url: str
"""URL of request."""
class HttpResponsePacket(DataPacket):
"""Class that reprensents an HTTP Response Packet."""
status_code: HttpStatusCode = None
"""Status code of the HTTP response."""

View File

@@ -0,0 +1,17 @@
from typing import Any
from pydantic import BaseModel
class DataPacket(BaseModel):
"""Data packet abstract class."""
payload: Any = None
"""Payload content of the packet."""
packet_payload_size: float = 0
"""Size of the packet."""
def get_packet_size(self) -> float:
"""Returns the size of the packet header and payload."""
return self.packet_payload_size + float(len(self.model_dump_json().encode("utf-8")))

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
from primaite import getLogger
from primaite.simulator.network.protocols.arp import ARPPacket
from primaite.simulator.network.protocols.packet import DataPacket
from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
@@ -132,6 +133,10 @@ class Frame(BaseModel):
@property
def size(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
"""The size of the Frame in Bytes."""
# get the payload size if it is a data packet
if isinstance(self.payload, DataPacket):
return self.payload.get_packet_size()
return float(len(self.model_dump_json().encode("utf-8")))
@property

View File

@@ -1,6 +1,6 @@
from typing import Dict
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.controller import DomainController
from primaite.simulator.network.container import Network
@@ -21,12 +21,12 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# pass through network actions to the network objects
am.add_action("network", Action(func=self.network._action_manager))
# pass through domain actions to the domain object
am.add_action("domain", Action(func=self.domain._action_manager))
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
# pass through network requests to the network objects
am.add_request("network", RequestType(func=self.network._request_manager))
# pass through domain requests to the domain object
am.add_request("domain", RequestType(func=self.domain._request_manager))
return am
def describe_state(self) -> Dict:

View File

@@ -81,18 +81,6 @@ class Application(IOSoftware):
"""
pass
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:return: True if successful, False otherwise.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.

View File

@@ -49,7 +49,7 @@ class DatabaseClient(Application):
"""
self.server_ip_address = server_ip_address
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {server_password=}.")
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self) -> bool:
"""Connect to a Database Service."""
@@ -60,13 +60,25 @@ class DatabaseClient(Application):
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
) -> bool:
"""
Connects the DatabaseClient to the DatabaseServer.
:param: server_ip_address: IP address of the database server
:type: server_ip_address: IPv4Address
:param: password: Password used to connect to the database server. Optional.
:type: password: Optional[str]
:param: is_reattempt: True if the connect request has been reattempted. Default False
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
if self.connected:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised")
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised")
self.server_ip_address = server_ip_address
return self.connected
else:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined")
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined")
return False
payload = {"type": "connect_request", "password": password}
software_manager: SoftwareManager = self.software_manager
@@ -83,15 +95,29 @@ class DatabaseClient(Application):
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
)
self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}")
self.sys_log.info(f"{self.name}: DatabaseClient disconnected from {self.server_ip_address}")
self.server_ip_address = None
self.connected = False
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
"""
Send a query to the connected database server.
:param: sql: SQL query to send to the database server.
:type: sql: str
:param: query_id: ID of the query, used as reference
:type: query_id: str
:param: is_reattempt: True if the query request has been reattempted. Default False
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
success = self._query_success_tracker.get(query_id)
if success:
self.sys_log.info(f"{self.name}: Query successful {sql}")
return True
self.sys_log.info(f"{self.name}: Unable to run query {sql}")
return False
else:
software_manager: SoftwareManager = self.software_manager

View File

@@ -1,7 +1,12 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Dict, Optional
from urllib.parse import urlparse
from primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.dns.dns_client import DNSClient
class WebBrowser(Application):
@@ -11,12 +16,29 @@ class WebBrowser(Application):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
domain_name: str
"The domain name of the webpage."
domain_name_ip_address: Optional[IPv4Address]
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
history: Dict[str]
"A dict that stores all of the previous domain names."
latest_response: Optional[HttpResponsePacket] = None
"""Keeps track of the latest HTTP response."""
def __init__(self, **kwargs):
kwargs["name"] = "WebBrowser"
kwargs["protocol"] = IPProtocol.TCP
# default for web is port 80
if kwargs.get("port") is None:
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.run()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the WebBrowser.
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
"""
return super().describe_state()
def reset_component_for_episode(self, episode: int):
"""
@@ -25,30 +47,90 @@ class WebBrowser(Application):
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name = ""
self.domain_name_ip_address = None
self.history = {}
self.latest_response = None
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def get_webpage(self, url: str) -> bool:
"""
Retrieve the webpage.
This should send a request to the web server which also requests for a list of users
:param: url: The address of the web page the browser requests
:type: url: str
"""
# reset latest response
self.latest_response = None
try:
parsed_url = urlparse(url)
except Exception:
self.sys_log.error(f"{url} is not a valid URL")
return False
# get the IP address of the domain name via DNS
dns_client: DNSClient = self.software_manager.software["DNSClient"]
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
# if domain does not exist, the request fails
if domain_exists:
# set current domain name IP address
self.domain_name_ip_address = dns_client.dns_cache[parsed_url.hostname]
else:
# check if url is an ip address
try:
self.domain_name_ip_address = IPv4Address(parsed_url.hostname)
except Exception:
# unable to deal with this request
self.sys_log.error(f"{self.name}: Unable to resolve URL {url}")
return False
# create HTTPRequest payload
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url=url)
# send request
return self.send(
payload=payload,
dest_ip_address=self.domain_name_ip_address,
dest_port=parsed_url.port if parsed_url.port else Port.HTTP,
)
def send(
self,
payload: HttpRequestPacket,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = Port.HTTP,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:param payload: The payload to send.
:return: True if successful, False otherwise.
"""
pass
self.sys_log.info(f"{self.name}: Sending HTTP {payload.request_method.name} {payload.request_url}")
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs
)
def receive(self, payload: HttpResponsePacket, session_id: Optional[str] = None, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
pass
if not isinstance(payload, HttpResponsePacket):
self.sys_log.error(f"{self.name} received a packet that is not an HttpResponsePacket")
return False
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
self.latest_response = payload
return True

View File

@@ -193,7 +193,7 @@ class SessionManager:
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
outbound_nic.send_frame(frame)
return outbound_nic.send_frame(frame)
def receive_frame(self, frame: Frame):
"""

View File

@@ -110,7 +110,7 @@ class SoftwareManager:
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
):
) -> bool:
"""
Send a payload to the SessionManager.
@@ -119,7 +119,7 @@ class SoftwareManager:
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
"""
self.session_manager.receive_payload_from_software_manager(
return self.session_manager.receive_payload_from_software_manager(
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)

View File

@@ -1,5 +1,6 @@
import sqlite3
from datetime import datetime
from ipaddress import IPv4Address
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
@@ -9,6 +10,7 @@ from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
@@ -23,6 +25,15 @@ class DatabaseService(Service):
password: Optional[str] = None
connections: Dict[str, datetime] = {}
backup_server: IPv4Address = None
"""IP address of the backup server."""
latest_backup_directory: str = None
"""Directory of latest backup."""
latest_backup_file_name: str = None
"""File name of latest backup."""
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseService"
kwargs["port"] = Port.POSTGRES_SERVER
@@ -30,6 +41,9 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._connect()
def _connect(self):
self._conn = sqlite3.connect(self._db_file.sim_path)
self._cursor = self._conn.cursor()
@@ -40,8 +54,10 @@ class DatabaseService(Service):
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql)
return [row[0] for row in results["data"]]
results = self._process_sql(sql, None)
if isinstance(results["data"], dict):
return list(results["data"].keys())
return []
def show(self, markdown: bool = False):
"""
@@ -58,6 +74,72 @@ class DatabaseService(Service):
table.add_row([row])
print(table)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.
:param: backup_server_ip: The IP address of the backup server
"""
self.backup_server = backup_server
def backup_database(self) -> bool:
"""Create a backup of the database to the configured backup server."""
# check if the backup server was configured
if self.backup_server is None:
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
self._conn.close()
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
# send backup copy of database file to FTP server
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self._db_file.folder.name,
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
real_file_path=self._db_file.sim_path,
)
self._connect()
if response:
return True
self.sys_log.error("Unable to create database backup.")
return False
def restore_backup(self) -> bool:
"""Restore a backup from backup server."""
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
# retrieve backup file from backup server
response = ftp_client_service.request_file(
src_folder_name=str(self.uuid),
src_file_name="database.db",
dest_folder_name="downloads",
dest_file_name="database.db",
dest_ip_address=self.backup_server,
)
if response:
self._conn.close()
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.move_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self._connect()
return self._db_file is not None
self.sys_log.error("Unable to restore database backup.")
return False
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
@@ -73,10 +155,10 @@ class DatabaseService(Service):
if self.password == password:
status_code = 200 # ok
self.connections[session_id] = datetime.now()
self.sys_log.info(f"Connect request for {session_id=} authorised")
self.sys_log.info(f"{self.name}: Connect request for {session_id=} authorised")
else:
status_code = 401 # Unauthorised
self.sys_log.info(f"Connect request for {session_id=} declined")
self.sys_log.info(f"{self.name}: Connect request for {session_id=} declined")
else:
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}

View File

@@ -27,6 +27,7 @@ class DNSClient(Service):
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def describe_state(self) -> Dict:
"""
@@ -71,19 +72,25 @@ class DNSClient(Service):
:param: session_id: The Session ID the payload is to originate from. Optional.
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
"""
# check if DNS server is configured
if self.dns_server is None:
self.sys_log.error(f"{self.name}: DNS Server is not configured")
return False
# check if the target domain is in the client's DNS cache
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
self.sys_log.info(
f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}"
f"{self.name}: Domain lookup for {target_domain} successful,"
f"resolves to {self.dns_cache[target_domain]}"
)
return True
else:
# return False if already reattempted
if is_reattempt:
self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed")
self.sys_log.info(f"{self.name}: Domain lookup for {target_domain} failed")
return False
else:
# send a request to check if domain name exists in the DNS Server
@@ -103,14 +110,13 @@ class DNSClient(Service):
self,
payload: DNSPacket,
session_id: Optional[str] = None,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
@@ -118,10 +124,11 @@ class DNSClient(Service):
:return: True if successful, False otherwise.
"""
# create DNS request packet
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return True
self.sys_log.info(f"{self.name}: Sending DNS request to resolve {payload.dns_request.domain_name_request}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs
)
def receive(
self,
@@ -132,9 +139,6 @@ class DNSClient(Service):
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
@@ -143,12 +147,16 @@ class DNSClient(Service):
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_reply is not None:
# add the IP address to the client cache
if payload.dns_reply.domain_name_ip_address:
self.sys_log.info(
f"{self.name}: Resolved domain name {payload.dns_request.domain_name_request} "
f"to {payload.dns_reply.domain_name_ip_address}"
)
self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address
return True
self.sys_log.error(f"Failed to resolve domain name {payload.dns_request.domain_name_request}")
return False

View File

@@ -26,6 +26,7 @@ class DNSServer(Service):
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def describe_state(self) -> Dict:
"""
@@ -95,13 +96,13 @@ class DNSServer(Service):
payload: DNSPacket = payload
if payload.dns_request is not None:
self.sys_log.info(
f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} "
f"{self.name}: Received domain lookup request for {payload.dns_request.domain_name_request} "
f"from session {session_id}"
)
# generate a reply with the correct DNS IP address
payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
self.sys_log.info(
f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} "
f"{self.name}: Responding to domain lookup request for {payload.dns_request.domain_name_request} "
f"with ip address: {payload.dns_reply.domain_name_ip_address}"
)
# send reply

View File

@@ -0,0 +1,277 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
class FTPClient(FTPServiceABC):
"""
A class for simulating an FTP client service.
This class inherits from the `Service` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
connected: bool = False
"""Keeps track of whether or not the FTP client is connected to an FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPClient"
kwargs["port"] = Port.FTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
:param: payload: The FTP Packet to process
:type: payload: FTPPacket
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
"""
# if client service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error("FTP Client is not running")
payload.status_code = FTPStatusCode.ERROR
return payload
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
# process client specific commands, otherwise call super
return super()._process_ftp_command(payload=payload, session_id=session_id, **kwargs)
def _connect_to_server(
self,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = Port.FTP,
session_id: Optional[str] = None,
is_reattempt: Optional[bool] = False,
) -> bool:
"""
Connects the client to a given FTP server.
:param: dest_ip_address: IP address of the FTP server the client needs to connect to. Optional.
:type: dest_ip_address: Optional[IPv4Address]
:param: dest_port: Port of the FTP server the client needs to connect to. Optional.
:type: dest_port: Optional[Port]
:param: is_reattempt: Set to True if attempt to connect to FTP Server has been attempted. Default False.
:type: is_reattempt: Optional[bool]
"""
# make sure the service is running before attempting
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error(f"FTPClient not running for {self.sys_log.hostname}")
return False
# normally FTP will choose a random port for the transfer, but using the FTP command port will do for now
# create FTP packet
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
)
if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id):
if payload.status_code == FTPStatusCode.OK:
self.sys_log.info(
f"{self.name}: Successfully connected to FTP Server "
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
)
return True
else:
if is_reattempt:
# reattempt failed
self.sys_log.info(
f"{self.name}: Unable to connect to FTP Server "
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
)
return False
else:
# try again
self._connect_to_server(
dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, is_reattempt=True
)
else:
self.sys_log.error(f"{self.name}: Unable to send FTPPacket")
return False
def _disconnect_from_server(
self, dest_ip_address: Optional[IPv4Address] = None, dest_port: Optional[Port] = Port.FTP
) -> bool:
"""
Connects the client from a given FTP server.
:param: dest_ip_address: IP address of the FTP server the client needs to disconnect from. Optional.
:type: dest_ip_address: Optional[IPv4Address]
:param: dest_port: Port of the FTP server the client needs to disconnect from. Optional.
:type: dest_port: Optional[Port]
:param: is_reattempt: Set to True if attempt to disconnect from FTP Server has been attempted. Default False.
:type: is_reattempt: Optional[bool]
"""
# send a disconnect request payload to FTP server
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.QUIT)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
)
if payload.status_code == FTPStatusCode.OK:
self.connected = False
return True
return False
def send_file(
self,
dest_ip_address: IPv4Address,
src_folder_name: str,
src_file_name: str,
dest_folder_name: str,
dest_file_name: str,
dest_port: Optional[Port] = Port.FTP,
session_id: Optional[str] = None,
real_file_path: Optional[str] = None,
) -> bool:
"""
Send a file to a target IP address.
The function checks if the file exists in the FTP Client host.
The STOR command is then sent to the FTP Server.
:param: dest_ip_address: The IP address of the machine that hosts the FTP Server.
:type: dest_ip_address: IPv4Address
:param: src_folder_name: The name of the folder that contains the file to send to the FTP Server.
:type: src_folder_name: str
:param: src_file_name: The name of the file to send to the FTP Server.
:type: src_file_name: str
:param: dest_folder_name: The name of the folder where the file will be stored in the FTP Server.
:type: dest_folder_name: str
:param: dest_file_name: The name of the file to be saved on the FTP Server.
:type: dest_file_name: str
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP.
:type: dest_port: Optional[Port]
:param: session_id: The id of the session
:type: session_id: Optional[str]
"""
# check if the file to transfer exists on the client
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
if not file_to_transfer:
self.sys_log.error(f"Unable to send file that does not exist: {src_folder_name}/{src_file_name}")
return False
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
return False
else:
self.sys_log.info(f"Sending file {src_folder_name}/{src_file_name} to {str(dest_ip_address)}")
# send STOR request
if self._send_data(
file=file_to_transfer,
dest_folder_name=dest_folder_name,
dest_file_name=dest_file_name,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
):
return self._disconnect_from_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
return False
def request_file(
self,
dest_ip_address: IPv4Address,
src_folder_name: str,
src_file_name: str,
dest_folder_name: str,
dest_file_name: str,
dest_port: Optional[Port] = Port.FTP,
) -> bool:
"""
Request a file from a target IP address.
Sends a RETR command to the FTP Server.
:param: dest_ip_address: The IP address of the machine that hosts the FTP Server.
:type: dest_ip_address: IPv4Address
:param: src_folder_name: The name of the folder that contains the file to send to the FTP Server.
:type: src_folder_name: str
:param: src_file_name: The name of the file to send to the FTP Server.
:type: src_file_name: str
:param: dest_folder_name: The name of the folder where the file will be stored in the FTP Server.
:type: dest_folder_name: str
:param: dest_file_name: The name of the file to be saved on the FTP Server.
:type: dest_file_name: str
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP.
:type: dest_port: Optional[Port]
"""
# check if FTP is currently connected to IP
self.connected = self._connect_to_server(dest_ip_address=dest_ip_address, dest_port=dest_port)
if not self.connected:
return False
else:
# send retrieve request
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.RETR,
ftp_command_args={
"src_folder_name": src_folder_name,
"src_file_name": src_file_name,
"dest_file_name": dest_file_name,
"dest_folder_name": dest_folder_name,
},
)
self.sys_log.info(f"Requesting file {src_folder_name}/{src_file_name} from {str(dest_ip_address)}")
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port
)
# the payload should have ok status code
if payload.status_code == FTPStatusCode.OK:
self.sys_log.info(f"{self.name}: File {src_folder_name}/{src_file_name} found in FTP server.")
return True
else:
self.sys_log.error(f"{self.name}: File {src_folder_name}/{src_file_name} does not exist in FTP server")
return False
def receive(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
:param: payload: FTPPacket payload.
:type: payload: FTPPacket
:param: session_id: ID of the session. Optional.
:type: session_id: Optional[str]
"""
if not isinstance(payload, FTPPacket):
self.sys_log.error(f"{payload} is not an FTP packet")
return False
"""
Ignore ftp payload if status code is None.
This helps prevent an FTP request loop - FTP client and servers can exist on
the same node.
"""
if payload.status_code is None:
return False
self._process_ftp_command(payload=payload, session_id=session_id)
return True

View File

@@ -0,0 +1,93 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
class FTPServer(FTPServiceABC):
"""
A class for simulating an FTP server service.
This class inherits from the `Service` class and provides methods to emulate FTP
RFC 959: https://datatracker.ietf.org/doc/html/rfc959
"""
server_password: Optional[str] = None
"""Password needed to connect to FTP server. Default is None."""
connections: Dict[str, IPv4Address] = {}
"""Current active connections to the FTP server."""
def __init__(self, **kwargs):
kwargs["name"] = "FTPServer"
kwargs["port"] = Port.FTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.start()
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
:param: payload: The FTP Packet to process
:type: payload: FTPPacket
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
"""
# error code by default
payload.status_code = FTPStatusCode.ERROR
# if server service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error("FTP Server not running")
return payload
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
if session_id:
session_details = self._get_session_details(session_id)
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
# process server specific commands, otherwise call super
if payload.ftp_command == FTPCommand.PORT:
# check that the port is valid
if isinstance(payload.ftp_command_args, Port) and payload.ftp_command_args.value in range(0, 65535):
# return successful connection
self.connections[session_id] = session_details.with_ip_address
payload.status_code = FTPStatusCode.OK
return payload
self.sys_log.error(f"Invalid Port {payload.ftp_command_args}")
return payload
if payload.ftp_command == FTPCommand.QUIT:
self.connections.pop(session_id)
payload.status_code = FTPStatusCode.OK
return payload
return super()._process_ftp_command(payload=payload, session_id=session_id, **kwargs)
def receive(self, payload: Any, session_id: Optional[str] = None, **kwargs) -> bool:
"""Receives a payload from the SessionManager."""
if not isinstance(payload, FTPPacket):
self.sys_log.error(f"{payload} is not an FTP packet")
return False
"""
Ignore ftp payload if status code is defined.
This means that an FTP server has already handled the packet and
prevents an FTP request loop - FTP client and servers can exist on
the same node.
"""
if payload.status_code is not None:
return False
self.send(self._process_ftp_command(payload=payload, session_id=session_id), session_id)
return True

View File

@@ -0,0 +1,184 @@
import shutil
from abc import ABC
from ipaddress import IPv4Address
from typing import Optional
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
class FTPServiceABC(Service, ABC):
"""
Abstract Base Class for FTP Client and Service.
Contains shared methods between both classes.
"""
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
:param: payload: The FTP Packet to process
:type: payload: FTPPacket
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
"""
if payload.ftp_command is not None:
self.sys_log.info(f"Received FTP {payload.ftp_command.name} command.")
# handle STOR request
if payload.ftp_command == FTPCommand.STOR:
# check that the file is created in the computed hosting the FTP server
if self._store_data(payload=payload):
payload.status_code = FTPStatusCode.OK
if payload.ftp_command == FTPCommand.RETR:
if self._retrieve_data(payload=payload, session_id=session_id):
payload.status_code = FTPStatusCode.OK
return payload
def _store_data(self, payload: FTPPacket) -> bool:
"""
Stores the data in the FTP Service's host machine.
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
try:
file_name = payload.ftp_command_args["dest_file_name"]
folder_name = payload.ftp_command_args["dest_folder_name"]
file_size = payload.ftp_command_args["file_size"]
real_file_path = payload.ftp_command_args.get("real_file_path")
is_real = real_file_path is not None
file = self.file_system.create_file(
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
)
self.sys_log.info(
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
f"{payload.ftp_command_args['dest_file_name']}"
)
if is_real:
shutil.copy(real_file_path, file.sim_path)
# file should exist
return self.file_system.get_file(file_name=file_name, folder_name=folder_name) is not None
except Exception as e:
self.sys_log.error(f"Unable to create file in {self.sys_log.hostname}: {e}")
return False
def _send_data(
self,
file: File,
dest_folder_name: str,
dest_file_name: str,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_response: bool = False,
) -> bool:
"""
Sends data from the host FTP Service's machine to another FTP Service's host machine.
:param: file: File to send to the target FTP Service.
:type: file: File
:param: dest_folder_name: The name of the folder where the file will be stored in the FTP Server.
:type: dest_folder_name: str
:param: dest_file_name: The name of the file to be saved on the FTP Server.
:type: dest_file_name: str
:param: dest_ip_address: The IP address of the machine that hosts the FTP Server.
:type: dest_ip_address: Optional[IPv4Address]
:param: dest_port: The open port of the machine that hosts the FTP Server. Default is Port.FTP.
:type: dest_port: Optional[Port]
:param: session_id: session ID linked to the FTP Packet. Optional.
:type: session_id: Optional[str]
:param: is_response: is true if the data being sent is in response to a request. Default False.
:type: is_response: bool
"""
# send STOR request
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
ftp_command_args={
"dest_folder_name": dest_folder_name,
"dest_file_name": dest_file_name,
"file_size": file.sim_size,
"real_file_path": file.sim_path if file.real else None,
},
packet_payload_size=file.sim_size,
status_code=FTPStatusCode.OK if is_response else None,
)
self.sys_log.info(f"{self.name}: Sending file {file.folder.name}/{file.name}")
response = self.send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
)
if response and payload.status_code == FTPStatusCode.OK:
return True
return False
def _retrieve_data(self, payload: FTPPacket, session_id: Optional[str] = None) -> bool:
"""
Handle the transfer of data from Server to Client.
:param: payload: The FTP Packet that contains the file data
:type: FTPPacket
"""
try:
# find the file
file_name = payload.ftp_command_args["src_file_name"]
folder_name = payload.ftp_command_args["src_folder_name"]
dest_folder_name = payload.ftp_command_args["dest_folder_name"]
dest_file_name = payload.ftp_command_args["dest_file_name"]
retrieved_file: File = self.file_system.get_file(folder_name=folder_name, file_name=file_name)
# if file does not exist, return an error
if not retrieved_file:
self.sys_log.error(
f"File {payload.ftp_command_args['dest_folder_name']}/"
f"{payload.ftp_command_args['dest_file_name']} does not exist in {self.sys_log.hostname}"
)
return False
else:
# send requested data
return self._send_data(
file=retrieved_file,
dest_file_name=dest_file_name,
dest_folder_name=dest_folder_name,
session_id=session_id,
is_response=True,
)
except Exception as e:
self.sys_log.error(f"Unable to retrieve file from {self.sys_log.hostname}: {e}")
return False
def send(
self,
payload: FTPPacket,
session_id: Optional[str] = None,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
self.sys_log.info(f"{self.name}: Sending FTP {payload.ftp_command.name} {payload.ftp_command_args}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs
)

View File

@@ -33,12 +33,14 @@ class DataManipulationBot(DatabaseClient):
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}.")
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
)
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"Attempting to start the {self.name}")
self.sys_log.info(f"{self.name}: Attempting to start the {self.name}")
super().run()
if not self.connected:
self.connect()
@@ -46,4 +48,4 @@ class DataManipulationBot(DatabaseClient):
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_io_address and payload.")
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.")

View File

@@ -2,7 +2,7 @@ from enum import Enum
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.software import IOSoftware
_LOGGER = getLogger(__name__)
@@ -43,16 +43,16 @@ class Service(IOSoftware):
_restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action("scan", Action(func=lambda request, context: self.scan()))
am.add_action("stop", Action(func=lambda request, context: self.stop()))
am.add_action("start", Action(func=lambda request, context: self.start()))
am.add_action("pause", Action(func=lambda request, context: self.pause()))
am.add_action("resume", Action(func=lambda request, context: self.resume()))
am.add_action("restart", Action(func=lambda request, context: self.restart()))
am.add_action("disable", Action(func=lambda request, context: self.disable()))
am.add_action("enable", Action(func=lambda request, context: self.enable()))
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_action("scan", RequestType(func=lambda request, context: self.scan()))
am.add_request("stop", RequestType(func=lambda request, context: self.stop()))
am.add_request("start", RequestType(func=lambda request, context: self.start()))
am.add_request("pause", RequestType(func=lambda request, context: self.pause()))
am.add_request("resume", RequestType(func=lambda request, context: self.resume()))
am.add_request("restart", RequestType(func=lambda request, context: self.restart()))
am.add_request("disable", RequestType(func=lambda request, context: self.disable()))
am.add_request("enable", RequestType(func=lambda request, context: self.enable()))
return am
def describe_state(self) -> Dict:

View File

@@ -0,0 +1,145 @@
from ipaddress import IPv4Address
from typing import Any, Optional
from urllib.parse import urlparse
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
HttpResponsePacket,
HttpStatusCode,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.service import Service
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
def __init__(self, **kwargs):
kwargs["name"] = "WebServer"
kwargs["protocol"] = IPProtocol.TCP
# default for web is port 80
if kwargs.get("port") is None:
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self._install_web_files()
self.start()
def _install_web_files(self):
"""
Installs the files hosted by the web service.
This is usually HTML, CSS, JS or PHP files requested by browsers to display the webpage.
"""
# index HTML main file
self.file_system.create_file(file_name="index.html", folder_name="primaite")
def _process_http_request(self, payload: HttpRequestPacket, session_id: Optional[str] = None) -> bool:
"""
Parse the HttpRequestPacket.
:param: payload: Payload containing th HttpRequestPacket
:type: payload: HttpRequestPacket
:param: session_id: Session id of the http request
:type: session_id: Optional[str]
"""
response = HttpResponsePacket()
self.sys_log.info(f"{self.name}: Received HTTP {payload.request_method.name} {payload.request_url}")
# check the type of HTTP request
if payload.request_method == HttpRequestMethod.GET:
response = self._handle_get_request(payload=payload)
elif payload.request_method == HttpRequestMethod.POST:
pass
else:
# send a method not allowed response
response.status_code = HttpStatusCode.METHOD_NOT_ALLOWED
# send response to web client
self.send(payload=response, session_id=session_id)
# return true if response is OK
return response.status_code == HttpStatusCode.OK
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
"""
Handle a GET HTTP request.
:param: payload: HTTP request payload
:type: payload: HttpRequestPacket
"""
response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload)
try:
parsed_url = urlparse(payload.request_url)
path = parsed_url.path.strip("/")
if len(path) < 1:
# query succeeded
response.status_code = HttpStatusCode.OK
if path.startswith("users"):
# get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
# get all users
if db_client.query("SELECT * FROM user;"):
# query succeeded
response.status_code = HttpStatusCode.OK
return response
except Exception:
# something went wrong on the server
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
return response
def send(
self,
payload: HttpResponsePacket,
session_id: Optional[str] = None,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param: payload: The payload to send.
:param: session_id: The id of the session
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:return: True if successful, False otherwise.
"""
self.sys_log.info(f"{self.name}: Sending HTTP Response {payload.status_code}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, **kwargs
)
def receive(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
:param: payload: The payload to send.
:param: session_id: The id of the session. Optional.
"""
# check if the payload is an HTTPPacket
if not isinstance(payload, HttpRequestPacket):
self.sys_log.error("Payload is not an HTTPPacket")
return False
return self._process_http_request(payload=payload, session_id=session_id)

View File

@@ -1,10 +1,12 @@
from abc import abstractmethod
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
@@ -85,17 +87,25 @@ class Software(SimComponent):
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
am.add_action(
def _init_request_manager(self) -> RequestManager:
am = super()._init_request_manager()
am.add_request(
"compromise",
Action(
RequestType(
func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED),
),
)
am.add_action("scan", Action(func=lambda request, context: self.scan()))
am.add_request("scan", RequestType(func=lambda request, context: self.scan()))
return am
def _get_session_details(self, session_id: str) -> Session:
"""
Returns the Session object from the given session id.
:param: session_id: ID of the session that needs details retrieved
"""
return self.software_manager.session_manager.sessions_by_uuid[session_id]
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -209,19 +219,27 @@ class IOSoftware(Software):
)
return state
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def send(
self,
payload: Any,
session_id: Optional[str] = None,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:param payload: The payload to send.
:param session_id: The identifier of the session that the payload is associated with.
:param kwargs: Additional keyword arguments specific to the implementation.
:return: True if the payload was successfully sent, False otherwise.
:return: True if successful, False otherwise.
"""
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id, **kwargs)
return self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
)
@abstractmethod
def receive(self, payload: Any, session_id: str, **kwargs) -> bool: