diff --git a/.gitignore b/.gitignore index fd115d62..f6231bac 100644 --- a/.gitignore +++ b/.gitignore @@ -144,9 +144,11 @@ cython_debug/ # IDE .idea/ docs/source/primaite-dependencies.rst +.vscode/ # outputs src/primaite/outputs/ +simulation_output/ # benchmark session outputs benchmark/output diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f2918aa..d9700f83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ SessionManager. - File System - ability to emulate a node's file system during a simulation - Example notebooks - There is currently 1 jupyter notebook which walks through using PrimAITE 1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP) +- Red Agent Services: + - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database) +- DNS Services: DNS Client and DNS Server ## [2.0.0] - 2023-07-26 diff --git a/docs/index.rst b/docs/index.rst index b2c5cfaa..19f95e95 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -98,6 +98,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/getting_started source/about source/config + source/simulation source/primaite_session source/custom_agent PrimAITE API diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index 7e9fe77f..e5c0d2c8 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -21,3 +21,5 @@ Contents simulation_components/network/router simulation_components/network/switch simulation_components/network/network + simulation_components/system/internal_frame_processing + simulation_components/system/software diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index f4d64b16..cb6d9392 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -2,7 +2,7 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -.. _about: +.. _network: Network ======= diff --git a/docs/source/simulation_components/network/router.rst b/docs/source/simulation_components/network/router.rst index aaa589cc..2dc81d3b 100644 --- a/docs/source/simulation_components/network/router.rst +++ b/docs/source/simulation_components/network/router.rst @@ -2,7 +2,7 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -.. _about: +.. _router: Router Module ============= diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst new file mode 100644 index 00000000..c9f8977a --- /dev/null +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -0,0 +1,58 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +DataManipulationBot +=================== + +The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements. + +Overview +-------- + +The bot is intended to simulate a malicious actor carrying out attacks like: + +- Dropping tables +- Deleting records +- Modifying data +On a database server by abusing an application's trusted database connectivity. + +Usage +----- + +- Create an instance and call ``configure`` to set: + - Target database server IP + - Database password (if needed) + - SQL statement payload +- Call ``run`` to connect and execute the statement. + +The bot handles connecting, executing the statement, and disconnecting. + +Example +------- + +.. code-block:: python + + client_1 = Computer( + hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1" + ) + client_1.power_on() + network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) + client_1.software_manager.install(DataManipulationBot) + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") + data_manipulation_bot.run() + +This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. + +Implementation +-------------- + +The bot extends ``DatabaseClient`` and leverages its connectivity. + +- Uses the Application base class for lifecycle management. +- Credentials and target IP set via ``configure``. +- ``run`` handles connecting, executing statement, and disconnecting. +- SQL payload executed via ``query`` method. +- Results in malicious SQL being executed on remote database server. diff --git a/docs/source/simulation_components/system/database_client_server.rst b/docs/source/simulation_components/system/database_client_server.rst new file mode 100644 index 00000000..99bbe25e --- /dev/null +++ b/docs/source/simulation_components/system/database_client_server.rst @@ -0,0 +1,70 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +Database Client Server +====================== + +Database Service +---------------- + +The ``DatabaseService`` provides a SQL database server simulation by extending the base Service class. + +Key capabilities +^^^^^^^^^^^^^^^^ + +- Initialises a SQLite database file in the ``Node``'s ``FileSystem`` upon creation. +- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs. +- Authenticates connections using a configurable password. +- Executes SQL queries against the SQLite database. +- Returns query results and status codes back to clients. +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +^^^^^ +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Clients connect, execute queries, and disconnect. +- Service runs on TCP port 5432 by default. + +Implementation +^^^^^^^^^^^^^^ + +- Uses SQLite for persistent storage. +- Creates the database file within the node's file system. +- Manages client connections in a dictionary by session ID. +- Processes SQL queries via the SQLite cursor and connection. +- Returns results and status codes in a standard dictionary format. +- Extends Service class for integration with ``SoftwareManager``. + +Database Client +--------------- + +The DatabaseClient provides a client interface for connecting to the ``DatabaseService``. + +Key features +^^^^^^^^^^^^ + +- Connects to the ``DatabaseService`` via the ``SoftwareManager``. +- Executes SQL queries and retrieves result sets. +- Handles connecting, querying, and disconnecting. +- Provides a simple ``query`` method for running SQL. + + +Usage +^^^^^ + +- Initialise with server IP address and optional password. +- Connect to the ``DatabaseService`` with ``connect``. +- Execute SQL queries via ``query``. +- Retrieve results in a dictionary. +- Disconnect when finished. + +Implementation +^^^^^^^^^^^^^^ + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Connect and disconnect methods manage sessions. +- Provides easy interface for applications to query database. +- Payloads serialised as dictionaries for transmission. +- Extends base Application class. diff --git a/docs/source/simulation_components/system/dns_client_server.rst b/docs/source/simulation_components/system/dns_client_server.rst new file mode 100644 index 00000000..f57f903b --- /dev/null +++ b/docs/source/simulation_components/system/dns_client_server.rst @@ -0,0 +1,56 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +DNS Client Server +================= + +DNS Server +---------- +Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class. + +Key capabilities +^^^^^^^^^^^^^^^^ + +- Simulates DNS requests and DNSPacket transfer across a network +- Registers domain names and the IP Address linked to the domain name +- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +^^^^^ +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) + +Implementation +^^^^^^^^^^^^^^ + +- DNS request and responses use a ``DNSPacket`` object +- Extends Service class for integration with ``SoftwareManager``. + +DNS Client +---------- + +The DNSClient provides a client interface for connecting to the ``DNSServer``. + +Key features +^^^^^^^^^^^^ + +- Connects to the ``DNSServer`` via the ``SoftwareManager``. +- Executes DNS lookup requests and keeps a cache of known domain name IP addresses. +- Handles connection to DNSServer and querying for domain name IP addresses. + +Usage +^^^^^ + +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) +- Execute domain name checks with ``check_domain_exists``. +- ``DNSClient`` will automatically add the IP Address of the domain into its cache + +Implementation +^^^^^^^^^^^^^^ + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for Nodes to find IP addresses via domain names. +- Extends base Service class. diff --git a/docs/source/simulation_components/system/internal_frame_processing.rst b/docs/source/simulation_components/system/internal_frame_processing.rst new file mode 100644 index 00000000..9c5356cc --- /dev/null +++ b/docs/source/simulation_components/system/internal_frame_processing.rst @@ -0,0 +1,98 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _internal_frame_processing: + +Internal Frame Processing +========================= + +Inbound +------- + +At the NIC +^^^^^^^^^^ +When a Frame is received on the Node's NIC: + +- The NIC checks if it is enabled. If so, it will process the Frame. +- The Frame's received timestamp is set. +- The Frame is captured by the NIC's PacketCapture if configured. +- The NIC decrements the IP Packet's TTL by 1. +- The NIC calls the Node's ``receive_frame`` method, passing itself as the receiving NIC and the Frame. + + +At the Node +^^^^^^^^^^^ + +When ``receive_frame`` is called on the Node: + +- The source IP address is added to the ARP cache if not already present. +- The Frame's protocol is checked: + - If ARP or ICMP, the Frame is passed to that protocol's handler method. + - Otherwise it is passed to the SessionManager's ``receive_frame`` method. + +At the SessionManager +^^^^^^^^^^^^^^^^^^^^^ + +When ``receive_frame`` is called on the SessionManager: + +- It extracts the key session details from the Frame: + - Protocol (TCP, UDP, etc) + - Source IP + - Destination IP + - Source Port + - Destination Port +- It checks if an existing Session matches these details. +- If no match, a new Session is created to represent this exchange. +- The payload and new/existing Session ID are passed to the SoftwareManager's ``receive_payload_from_session_manager`` method. + +At the SoftwareManager +^^^^^^^^^^^^^^^^^^^^^^ + +Inside ``receive_payload_from_session_manager``: + +- The SoftwareManager checks its port/protocol mapping to find which Service or Application is listening on the destination port and protocol. +- The payload and Session ID are forwarded to that receiver Service/Application instance via their ``receive`` method. +- The Service/Application can then process the payload as needed. + +Outbound +-------- + +At the Service/Application +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When a Service or Application needs to send a payload: + +- It calls the SoftwareManager's ``send_payload_to_session_manager`` method. +- Passes the payload, and either destination IP and destination port for new payloads, or session id for existing sessions. + +At the SoftwareManager +^^^^^^^^^^^^^^^^^^^^^^ + +Inside ``send_payload_to_session_manager``: + +- The SoftwareManager forwards the payload and details through to to the SessionManager's ``receive_payload_from_software_manager`` method. + +At the SessionManager +^^^^^^^^^^^^^^^^^^^^^ + +When ``receive_payload_from_software_manager`` is called: + +- If a Session ID was provided, it looks up the Session. +- Gets the destination MAC address by checking the ARP cache. +- If no Session ID was provided, the destination Port, IP address and Mac Address are used along with the outbound IP Address and Mac Address to create a new Session. +- Calls `send_payload_to_nic`` to construct and send the Frame. + +When ``send_payload_to_nic`` is called: + +- It constructs a new Frame with the payload, using the source NIC's MAC, source IP, destination MAC, etc. +- The outbound NIC is looked up via the ARP cache based on destination IP. +- The constructed Frame is passed to the outbound NIC's ``send_frame`` method. + +At the NIC +^^^^^^^^^^ + +When ``send_frame`` is called: + +- The NIC checks if it is enabled before sending. +- If enabled, it sends the Frame out to the connected Link. diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst new file mode 100644 index 00000000..275fdaf9 --- /dev/null +++ b/docs/source/simulation_components/system/software.rst @@ -0,0 +1,20 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +Software +======== + + + + +Contents +######## + +.. toctree:: + :maxdepth: 8 + + database_client_server + data_manipulation_bot + dns_client_server diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 1cfe7f49..8c55542f 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,5 +1,14 @@ +from datetime import datetime + from primaite import _PRIMAITE_ROOT -TEMP_SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" +SIM_OUTPUT = None "A path at the repo root dir to use temporarily for sim output testing while in dev." # TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path + +if not SIM_OUTPUT: + session_timestamp = datetime.now() + date_dir = session_timestamp.strftime("%Y-%m-%d") + sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path + SIM_OUTPUT.mkdir(exist_ok=True, parents=True) diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index baf7bd2c..a2e1550c 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -149,8 +149,8 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.file_system.file_system_file_type import FileSystemFileType\n", - "from primaite.simulator.file_system.file_system_file import FileSystemFile" + "from primaite.simulator.file_system.file_type import FileType\n", + "from primaite.simulator.file_system.file_system import File" ] }, { @@ -160,7 +160,7 @@ "outputs": [], "source": [ "my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n", - "my_pc_downloads_folder.add_file(FileSystemFile(name=\"firefox_installer.zip\",file_type=FileSystemFileType.ZIP))" + "my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))" ] }, { @@ -171,7 +171,7 @@ { "data": { "text/plain": [ - "FileSystemFile(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=, action_manager=None)" + "File(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=, action_manager=None)" ] }, "execution_count": 9, @@ -181,7 +181,7 @@ ], "source": [ "my_server_folder = my_server.file_system.create_folder(\"static\")\n", - "my_server.file_system.create_file(\"favicon.ico\", file_type=FileSystemFileType.PNG)" + "my_server.file_system.create_file(\"favicon.ico\", file_type=FileType.PNG)" ] }, { diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 0fbc33fd..78e6139f 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -222,6 +222,9 @@ class SimComponent(BaseModel): :param action: List describing the action to apply to this object. :type action: List[str] + + :param: context: Dict containing context for actions + :type context: Dict """ if self.action_manager is None: return diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 79159e60..b2037729 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -1,242 +1,519 @@ -from random import choice +from __future__ import annotations + +import math +import os.path +import shutil +from pathlib import Path from typing import Dict, Optional +from prettytable import MARKDOWN, PrettyTable + from primaite import getLogger from primaite.simulator.core import SimComponent -from primaite.simulator.file_system.file_system_file import FileSystemFile -from primaite.simulator.file_system.file_system_file_type import FileSystemFileType -from primaite.simulator.file_system.file_system_folder import FileSystemFolder +from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension +from primaite.simulator.system.core.sys_log import SysLog _LOGGER = getLogger(__name__) -class FileSystem(SimComponent): - """Class that contains all the simulation File System.""" +def convert_size(size_bytes: int) -> str: + """ + Convert a file size from bytes to a string with a more human-readable format. - folders: Dict[str, FileSystemFolder] = {} - """List containing all the folders in the file system.""" + This function takes the size of a file in bytes and converts it to a string representation with appropriate size + units (B, KB, MB, GB, etc.). + + :param size_bytes: The size of the file in bytes. + :return: The human-readable string representation of the file size. + """ + if size_bytes == 0: + return "0 B" + + # Tuple of size units starting from Bytes up to Yottabytes + size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + + # Calculate the index (i) that will be used to select the appropriate size unit from size_name + i = int(math.floor(math.log(size_bytes, 1024))) + + # Calculate the adjusted size value (s) in terms of the new size unit + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + + return f"{s} {size_name[i]}" + + +class FileSystemItemABC(SimComponent): + """ + Abstract base class for file system items used in the file system simulation. + + :ivar name: The name of the FileSystemItemABC. + """ + + name: str + "The name of the FileSystemItemABC." def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - :return: Current state of this object and child objects. - :rtype: Dict """ state = super().describe_state() - state.update({"folders": {uuid: folder.describe_state() for uuid, folder in self.folders.items()}}) + state.update( + { + "name": self.name, + } + ) return state - def get_folders(self) -> Dict: - """Returns the list of folders.""" - return self.folders + @property + def size_str(self) -> str: + """ + Get the file size in a human-readable string format. + + This property makes use of the :func:`convert_size` function to convert the `self.size` attribute to a string + that is easier to read and understand. + + :return: The human-readable string representation of the file size. + """ + return convert_size(self.size) + + +class FileSystem(SimComponent): + """Class that contains all the simulation File System.""" + + folders: Dict[str, Folder] = {} + "List containing all the folders in the file system." + _folders_by_name: Dict[str, Folder] = {} + sys_log: SysLog + sim_root: Path + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Ensure a default root folder + if not self.folders: + self.create_folder("root") + + @property + def size(self) -> int: + """ + Calculate and return the total size of all folders in the file system. + + :return: The sum of the sizes of all folders in the file system. + """ + return sum(folder.size for folder in self.folders.values()) + + def show(self, markdown: bool = False, full: bool = False): + """ + Prints a table of the FileSystem, displaying either just folders or full files. + + :param markdown: Flag indicating if output should be in markdown format. + :param full: Flag indicating if to show full files. + """ + headers = ["Folder", "Size"] + if full: + headers[0] = "File Path" + table = PrettyTable(headers) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} File System" + for folder in self.folders.values(): + if not full: + table.add_row([folder.name, folder.size_str]) + else: + for file in folder.files.values(): + table.add_row([file.path, file.size_str]) + if full: + print(table.get_string(sortby="File Path")) + else: + print(table.get_string(sortby="Folder")) + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + :return: Current state of this object and child objects. + """ + state = super().describe_state() + state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()} + return state + + def create_folder(self, folder_name: str) -> Folder: + """ + Creates a Folder and adds it to the list of folders. + + :param folder_name: The name of the folder. + """ + # check if folder with name already exists + if self.get_folder(folder_name): + raise Exception(f"Cannot create folder as it already exists: {folder_name}") + + folder = Folder(name=folder_name, fs=self) + + self.folders[folder.uuid] = folder + self._folders_by_name[folder.name] = folder + self.sys_log.info(f"Created folder /{folder.name}") + return folder + + def delete_folder(self, folder_name: str): + """ + Deletes a folder, removes it from the folders list and removes any child folders and files. + + :param folder_name: The name of the folder. + """ + if folder_name == "root": + self.sys_log.warning("Cannot delete the root folder.") + return + folder = self._folders_by_name.get(folder_name) + if folder: + for file in folder.files.values(): + self.delete_file(file) + self.folders.pop(folder.uuid) + self._folders_by_name.pop(folder.name) + self.sys_log.info(f"Deleted folder /{folder.name} and its contents") + else: + _LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}") def create_file( self, file_name: str, - size: Optional[float] = None, - file_type: Optional[FileSystemFileType] = None, - folder: Optional[FileSystemFolder] = None, - folder_uuid: Optional[str] = None, - ) -> FileSystemFile: + size: Optional[int] = None, + file_type: Optional[FileType] = None, + folder_name: Optional[str] = None, + real: bool = False, + ) -> File: """ - Creates a FileSystemFile and adds it to the list of files. + Creates a File and adds it to the list of files. - If no size or file_type are provided, one will be chosen randomly. - If no folder_uuid or folder is provided, a new folder will be created. - - :param: file_name: The file name - :type: file_name: str - - :param: size: The size the file takes on disk. - :type: size: Optional[float] - - :param: file_type: The type of the file - :type: Optional[FileSystemFileType] - - :param: folder: The folder to add the file to - :type: folder: Optional[FileSystemFolder] - - :param: folder_uuid: The uuid of the folder to add the file to - :type: folder_uuid: Optional[str] + :param file_name: The file name. + :param size: The size the file takes on disk in bytes. + :param file_type: The type of the file. + :param folder_name: The folder to add the file to. + :param real: "Indicates whether the File is actually a real file in the Node sim fs output." """ - file = None - folder = None - - if file_type is None: - file_type = self.get_random_file_type() - - # if no folder uuid provided, create a folder and add file to it - if folder_uuid is not None: - # otherwise check for existence and add file - folder = self.get_folder_by_id(folder_uuid) - - if folder is not None: + if folder_name: # check if file with name already exists - if folder.get_file_by_name(file_name): - raise Exception(f'File with name "{file_name}" already exists.') - - file = FileSystemFile(name=file_name, size=size, file_type=file_type) - folder.add_file(file=file) + folder = self._folders_by_name.get(folder_name) + # If not then create it + if not folder: + folder = self.create_folder(folder_name) else: - # check if a "root" folder exists - folder = self.get_folder_by_name("root") - if folder is None: - # create a root folder - folder = FileSystemFolder(name="root") + # Use root folder if folder_name not supplied + folder = self._folders_by_name["root"] - # add file to root folder - file = FileSystemFile(name=file_name, size=size, file_type=file_type) - folder.add_file(file) - self.folders[folder.uuid] = folder + # Create the file and add it to the folder + file = File( + name=file_name, + sim_size=size, + file_type=file_type, + folder=folder, + real=real, + sim_path=self.sim_root if real else None, + ) + folder.add_file(file) + self.sys_log.info(f"Created file /{file.path}") return file - def create_folder( - self, - folder_name: str, - ) -> FileSystemFolder: + def get_file(self, folder_name: str, file_name: str) -> Optional[File]: """ - Creates a FileSystemFolder and adds it to the list of folders. + Retrieve a file by its name from a specific folder. - :param: folder_name: The name of the folder - :type: folder_name: str + :param folder_name: The name of the folder where the file resides. + :param file_name: The name of the file to be retrieved, including its extension. + :return: An instance of File if it exists, otherwise `None`. """ - # check if folder with name already exists - if self.get_folder_by_name(folder_name): - raise Exception(f'Folder with name "{folder_name}" already exists.') + folder = self.get_folder(folder_name) + if folder: + return folder.get_file(file_name) + self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}") - folder = FileSystemFolder(name=folder_name) - - self.folders[folder.uuid] = folder - return folder - - def delete_file(self, file: Optional[FileSystemFile] = None): + def delete_file(self, folder_name: str, file_name: str): """ - Deletes a file and removes it from the files list. + Delete a file by its name from a specific folder. - :param file: The file to delete - :type file: Optional[FileSystemFile] + :param folder_name: The name of the folder containing the file. + :param file_name: The name of the file to be deleted, including its extension. """ - # iterate through folders to delete the item with the matching uuid - for key in self.folders: - self.get_folder_by_id(key).remove_file(file) + folder = self.get_folder(folder_name) + if folder: + file = folder.get_file(file_name) + if file: + folder.remove_file(file) + self.sys_log.info(f"Deleted file /{file.path}") - def delete_folder(self, folder: FileSystemFolder): + def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): """ - Deletes a folder, removes it from the folders list and removes any child folders and files. + Move a file from one folder to another. - :param folder: The folder to remove - :type folder: FileSystemFolder + :param src_folder_name: The name of the source folder containing the file. + :param src_file_name: The name of the file to be moved. + :param dst_folder_name: The name of the destination folder. """ - if folder is None or not isinstance(folder, FileSystemFolder): - raise Exception(f"Invalid folder: {folder}") + file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) + if file: + src_folder = file.folder - if self.folders.get(folder.uuid): - del self.folders[folder.uuid] + # remove file from src + src_folder.remove_file(file) + dst_folder = self.get_folder(folder_name=dst_folder_name) + if not dst_folder: + dst_folder = self.create_folder(dst_folder_name) + # add file to dst + dst_folder.add_file(file) + if file.real: + old_sim_path = file.sim_path + file.sim_path = file.folder.fs.sim_root / file.path + file.sim_path.parent.mkdir(exist_ok=True) + shutil.move(old_sim_path, file.sim_path) + + def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): + """ + Copy a file from one folder to another. + + :param src_folder_name: The name of the source folder containing the file. + :param src_file_name: The name of the file to be copied. + :param dst_folder_name: The name of the destination folder. + """ + file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) + if file: + dst_folder = self.get_folder(folder_name=dst_folder_name) + if not dst_folder: + dst_folder = self.create_folder(dst_folder_name) + new_file = file.make_copy(dst_folder=dst_folder) + dst_folder.add_file(new_file) + if file.real: + new_file.sim_path.parent.mkdir(exist_ok=True) + shutil.copy2(file.sim_path, new_file.sim_path) + + def get_folder(self, folder_name: str) -> Optional[Folder]: + """ + Get a folder by its name if it exists. + + :param folder_name: The folder name. + :return: The matching Folder. + """ + return self._folders_by_name.get(folder_name) + + def get_folder_by_id(self, folder_uuid: str) -> Optional[Folder]: + """ + Get a folder by its uuid if it exists. + + :param folder_uuid: The folder uuid. + :return: The matching Folder. + """ + return self.folders.get(folder_uuid) + + +class Folder(FileSystemItemABC): + """Simulation Folder.""" + + fs: FileSystem + "The FileSystem the Folder is in." + files: Dict[str, File] = {} + "Files stored in the folder." + _files_by_name: Dict[str, File] = {} + "Files by their name as .." + is_quarantined: bool = False + "Flag that marks the folder as quarantined if true." + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + :return: Current state of this object and child objects. + """ + state = super().describe_state() + state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()} + state["is_quarantined"] = self.is_quarantined + return state + + def show(self, markdown: bool = False): + """ + Display the contents of the Folder in tabular format. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + table = PrettyTable(["File", "Size"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.fs.sys_log.hostname} File System Folder ({self.name})" + for file in self.files.values(): + table.add_row([file.name, file.size_str]) + print(table.get_string(sortby="File")) + + @property + def size(self) -> int: + """ + Calculate and return the total size of all files in the folder. + + :return: The total size of all files in the folder. If no files exist or all have `None` + size, returns 0. + """ + return sum(file.size for file in self.files.values() if file.size is not None) + + def get_file(self, file_name: str) -> Optional[File]: + """ + Get a file by its name. + + File name must be the filename and prefix, like 'memo.docx'. + + :param file_name: The file name. + :return: The matching File. + """ + # TODO: Increment read count? + return self._files_by_name.get(file_name) + + def get_file_by_id(self, file_uuid: str) -> File: + """ + Get a file by its uuid. + + :param file_uuid: The file uuid. + :return: The matching File. + """ + return self.files.get(file_uuid) + + def add_file(self, file: File): + """ + Adds a file to the folder. + + :param File file: The File object to be added to the folder. + :raises Exception: If the provided `file` parameter is None or not an instance of the + `File` class. + """ + if file is None or not isinstance(file, File): + raise Exception(f"Invalid file: {file}") + + # check if file with id already exists in folder + if file.uuid in self.files: + _LOGGER.debug(f"File with id {file.uuid} already exists in folder") else: - _LOGGER.debug(f"File with UUID {folder.uuid} was not found.") + # add to list + self.files[file.uuid] = file + self._files_by_name[file.name] = file + file.folder = self - def move_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder): + def remove_file(self, file: Optional[File]): """ - Moves a file from one folder to another. + Removes a file from the folder list. - can provide + The method can take a File object or a file id. - :param: file: The file to move - :type: file: FileSystemFile - - :param: src_folder: The folder where the file is located - :type: FileSystemFolder - - :param: target_folder: The folder where the file should be moved to - :type: FileSystemFolder + :param file: The file to remove """ - # check that the folders exist - if src_folder is None: - raise Exception("Source folder not provided") + if file is None or not isinstance(file, File): + raise Exception(f"Invalid file: {file}") - if target_folder is None: - raise Exception("Target folder not provided") + if self.files.get(file.uuid): + self.files.pop(file.uuid) + self._files_by_name.pop(file.name) + else: + _LOGGER.debug(f"File with UUID {file.uuid} was not found.") - if file is None: - raise Exception("File to be moved is None") + def quarantine(self): + """Quarantines the File System Folder.""" + if not self.is_quarantined: + self.is_quarantined = True + self.fs.sys_log.info(f"Quarantined folder ./{self.name}") - # check if file with name already exists - if target_folder.get_file_by_name(file.name): - raise Exception(f'Folder with name "{file.name}" already exists.') + def unquarantine(self): + """Unquarantine of the File System Folder.""" + if self.is_quarantined: + self.is_quarantined = False + self.fs.sys_log.info(f"Quarantined folder ./{self.name}") - # remove file from src - src_folder.remove_file(file) + def quarantine_status(self) -> bool: + """Returns true if the folder is being quarantined.""" + return self.is_quarantined - # add file to target - target_folder.add_file(file) - def copy_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder): +class File(FileSystemItemABC): + """ + Class representing a file in the simulation. + + :ivar Folder folder: The folder in which the file resides. + :ivar FileType file_type: The type of the file. + :ivar Optional[int] sim_size: The simulated file size. + :ivar bool real: Indicates if the file is actually a real file in the Node sim fs output. + :ivar Optional[Path] sim_path: The path if the file is real. + """ + + folder: Folder + "The Folder the File is in." + file_type: FileType + "The type of File." + sim_size: Optional[int] = None + "The simulated file size." + real: bool = False + "Indicates whether the File is actually a real file in the Node sim fs output." + sim_path: Optional[Path] = None + "The Path if real is True." + + def __init__(self, **kwargs): """ - Copies a file from one folder to another. + Initialise File class. - can provide - - :param: file: The file to move - :type: file: FileSystemFile - - :param: src_folder: The folder where the file is located - :type: FileSystemFolder - - :param: target_folder: The folder where the file should be moved to - :type: FileSystemFolder + :param name: The name of the file. + :param file_type: The FileType of the file + :param size: The size of the FileSystemItemABC """ - if src_folder is None: - raise Exception("Source folder not provided") + has_extension = "." in kwargs["name"] - if target_folder is None: - raise Exception("Target folder not provided") + # Attempt to use the file type extension to set/override the FileType + if has_extension: + extension = kwargs["name"].split(".")[-1] + kwargs["file_type"] = get_file_type_from_extension(extension) + else: + # If the file name does not have a extension, override file type to FileType.UNKNOWN + if not kwargs["file_type"]: + kwargs["file_type"] = FileType.UNKNOWN + if kwargs["file_type"] != FileType.UNKNOWN: + kwargs["name"] = f"{kwargs['name']}.{kwargs['file_type'].name.lower()}" - if file is None: - raise Exception("File to be moved is None") + # set random file size if none provided + if not kwargs.get("sim_size"): + kwargs["sim_size"] = kwargs["file_type"].default_size + super().__init__(**kwargs) + if self.real: + self.sim_path = self.folder.fs.sim_root / self.path + if not self.sim_path.exists(): + self.sim_path.parent.mkdir(exist_ok=True, parents=True) + with open(self.sim_path, mode="a"): + pass - # check if file with name already exists - if target_folder.get_file_by_name(file.name): - raise Exception(f'Folder with name "{file.name}" already exists.') - - # add file to target - target_folder.add_file(file) - - def get_file_by_id(self, file_id: str) -> FileSystemFile: - """Checks if the file exists in any file system folders.""" - for key in self.folders: - file = self.folders[key].get_file_by_id(file_id=file_id) - if file is not None: - return file - - def get_folder_by_name(self, folder_name: str) -> Optional[FileSystemFolder]: + def make_copy(self, dst_folder: Folder) -> File: """ - Returns a the first folder with a matching name. + Create a copy of the current File object in the given destination folder. - :return: Returns the first FileSydtemFolder with a matching name + :param Folder dst_folder: The destination folder for the copied file. + :return: A new File object that is a copy of the current file. """ - matching_folder = None - for key in self.folders: - if self.folders[key].name == folder_name: - matching_folder = self.folders[key] - break - return matching_folder + return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"})) - def get_folder_by_id(self, folder_id: str) -> FileSystemFolder: + @property + def path(self) -> str: """ - Checks if the folder exists. + Get the path of the file in the file system. - :param: folder_id: The id of the folder to find - :type: folder_id: str + :return: The full path of the file. """ - return self.folders[folder_id] + return f"{self.folder.name}/{self.name}" - def get_random_file_type(self) -> FileSystemFileType: + @property + def size(self) -> int: """ - Returns a random FileSystemFileTypeEnum. + Get the size of the file in bytes. - :return: A random file type Enum + :return: The size of the file in bytes. """ - return choice(list(FileSystemFileType)) + if self.real: + return os.path.getsize(self.sim_path) + return self.sim_size + + def describe_state(self) -> Dict: + """Produce a dictionary describing the current state of this object.""" + state = super().describe_state() + state["size"] = self.size + state["file_type"] = self.file_type.name + return state diff --git a/src/primaite/simulator/file_system/file_system_file.py b/src/primaite/simulator/file_system/file_system_file.py deleted file mode 100644 index c25f5973..00000000 --- a/src/primaite/simulator/file_system/file_system_file.py +++ /dev/null @@ -1,55 +0,0 @@ -from random import choice -from typing import Dict - -from primaite.simulator.file_system.file_system_file_type import file_type_sizes_KB, FileSystemFileType -from primaite.simulator.file_system.file_system_item_abc import FileSystemItem - - -class FileSystemFile(FileSystemItem): - """Class that represents a file in the simulation.""" - - file_type: FileSystemFileType = None - """The type of the FileSystemFile""" - - def __init__(self, **kwargs): - """ - Initialise FileSystemFile class. - - :param name: The name of the file. - :type name: str - - :param file_type: The FileSystemFileType of the file - :type file_type: Optional[FileSystemFileType] - - :param size: The size of the FileSystemItem - :type size: Optional[float] - """ - # set random file type if none provided - - # set random file type if none provided - if kwargs.get("file_type") is None: - kwargs["file_type"] = choice(list(FileSystemFileType)) - - # set random file size if none provided - if kwargs.get("size") is None: - kwargs["size"] = file_type_sizes_KB[kwargs["file_type"]] - - super().__init__(**kwargs) - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - state = super().describe_state() - state.update( - { - "uuid": self.uuid, - "file_type": self.file_type.name, - } - ) - return state diff --git a/src/primaite/simulator/file_system/file_system_file_type.py b/src/primaite/simulator/file_system/file_system_file_type.py deleted file mode 100644 index 88aeb430..00000000 --- a/src/primaite/simulator/file_system/file_system_file_type.py +++ /dev/null @@ -1,132 +0,0 @@ -from enum import Enum - - -class FileSystemFileType(str, Enum): - """An enumeration of common file types.""" - - UNKNOWN = 0 - "Unknown file type." - - # Text formats - TXT = 1 - "Plain text file." - DOC = 2 - "Microsoft Word document (.doc)" - DOCX = 3 - "Microsoft Word document (.docx)" - PDF = 4 - "Portable Document Format." - HTML = 5 - "HyperText Markup Language file." - XML = 6 - "Extensible Markup Language file." - CSV = 7 - "Comma-Separated Values file." - - # Spreadsheet formats - XLS = 8 - "Microsoft Excel file (.xls)" - XLSX = 9 - "Microsoft Excel file (.xlsx)" - - # Image formats - JPEG = 10 - "JPEG image file." - PNG = 11 - "PNG image file." - GIF = 12 - "GIF image file." - BMP = 13 - "Bitmap image file." - - # Audio formats - MP3 = 14 - "MP3 audio file." - WAV = 15 - "WAV audio file." - - # Video formats - MP4 = 16 - "MP4 video file." - AVI = 17 - "AVI video file." - MKV = 18 - "MKV video file." - FLV = 19 - "FLV video file." - - # Presentation formats - PPT = 20 - "Microsoft PowerPoint file (.ppt)" - PPTX = 21 - "Microsoft PowerPoint file (.pptx)" - - # Web formats - JS = 22 - "JavaScript file." - CSS = 23 - "Cascading Style Sheets file." - - # Programming languages - PY = 24 - "Python script file." - C = 25 - "C source code file." - CPP = 26 - "C++ source code file." - JAVA = 27 - "Java source code file." - - # Compressed file types - RAR = 28 - "RAR archive file." - ZIP = 29 - "ZIP archive file." - TAR = 30 - "TAR archive file." - GZ = 31 - "Gzip compressed file." - - # Database file types - MDF = 32 - "MS SQL Server primary database file" - NDF = 33 - "MS SQL Server secondary database file" - LDF = 34 - "MS SQL Server transaction log" - - -file_type_sizes_KB = { - FileSystemFileType.UNKNOWN: 0, - FileSystemFileType.TXT: 4, - FileSystemFileType.DOC: 50, - FileSystemFileType.DOCX: 30, - FileSystemFileType.PDF: 100, - FileSystemFileType.HTML: 15, - FileSystemFileType.XML: 10, - FileSystemFileType.CSV: 15, - FileSystemFileType.XLS: 100, - FileSystemFileType.XLSX: 25, - FileSystemFileType.JPEG: 100, - FileSystemFileType.PNG: 40, - FileSystemFileType.GIF: 30, - FileSystemFileType.BMP: 300, - FileSystemFileType.MP3: 5000, - FileSystemFileType.WAV: 25000, - FileSystemFileType.MP4: 25000, - FileSystemFileType.AVI: 50000, - FileSystemFileType.MKV: 50000, - FileSystemFileType.FLV: 15000, - FileSystemFileType.PPT: 200, - FileSystemFileType.PPTX: 100, - FileSystemFileType.JS: 10, - FileSystemFileType.CSS: 5, - FileSystemFileType.PY: 5, - FileSystemFileType.C: 5, - FileSystemFileType.CPP: 10, - FileSystemFileType.JAVA: 10, - FileSystemFileType.RAR: 1000, - FileSystemFileType.ZIP: 1000, - FileSystemFileType.TAR: 1000, - FileSystemFileType.GZ: 800, -} diff --git a/src/primaite/simulator/file_system/file_system_folder.py b/src/primaite/simulator/file_system/file_system_folder.py deleted file mode 100644 index 4e461a3a..00000000 --- a/src/primaite/simulator/file_system/file_system_folder.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Dict, Optional - -from primaite import getLogger -from primaite.simulator.file_system.file_system_file import FileSystemFile -from primaite.simulator.file_system.file_system_item_abc import FileSystemItem - -_LOGGER = getLogger(__name__) - - -class FileSystemFolder(FileSystemItem): - """Simulation FileSystemFolder.""" - - files: Dict[str, FileSystemFile] = {} - """List of files stored in the folder.""" - - is_quarantined: bool = False - """Flag that marks the folder as quarantined if true.""" - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - state = super().describe_state() - state.update( - { - "files": {uuid: file.describe_state() for uuid, file in self.files.items()}, - "is_quarantined": self.is_quarantined, - } - ) - return state - - def get_file_by_id(self, file_id: str) -> FileSystemFile: - """Return a FileSystemFile with the matching id.""" - return self.files.get(file_id) - - def get_file_by_name(self, file_name: str) -> FileSystemFile: - """Return a FileSystemFile with the matching id.""" - return next((f for f in list(self.files) if f.name == file_name), None) - - def add_file(self, file: FileSystemFile): - """Adds a file to the folder list.""" - if file is None or not isinstance(file, FileSystemFile): - raise Exception(f"Invalid file: {file}") - - # check if file with id already exists in folder - if file.uuid in self.files: - _LOGGER.debug(f"File with id {file.uuid} already exists in folder") - else: - # add to list - self.files[file.uuid] = file - self.size += file.size - - def remove_file(self, file: Optional[FileSystemFile]): - """ - Removes a file from the folder list. - - The method can take a FileSystemFile object or a file id. - - :param: file: The file to remove - :type: Optional[FileSystemFile] - """ - if file is None or not isinstance(file, FileSystemFile): - raise Exception(f"Invalid file: {file}") - - if self.files.get(file.uuid): - del self.files[file.uuid] - - self.size -= file.size - else: - _LOGGER.debug(f"File with UUID {file.uuid} was not found.") - - def quarantine(self): - """Quarantines the File System Folder.""" - self.is_quarantined = True - - def end_quarantine(self): - """Ends the quarantine of the File System Folder.""" - self.is_quarantined = False - - def quarantine_status(self) -> bool: - """Returns true if the folder is being quarantined.""" - return self.is_quarantined diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py deleted file mode 100644 index 3b368819..00000000 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict - -from primaite.simulator.core import SimComponent - - -class FileSystemItem(SimComponent): - """Abstract base class for FileSystemItems used in the file system simulation.""" - - name: str - """The name of the FileSystemItem.""" - - size: float = 0 - """The size the item takes up on disk.""" - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - state = super().describe_state() - state.update( - { - "name": self.name, - "size": self.size, - } - ) - return state diff --git a/src/primaite/simulator/file_system/file_type.py b/src/primaite/simulator/file_system/file_type.py new file mode 100644 index 00000000..f87cd86f --- /dev/null +++ b/src/primaite/simulator/file_system/file_type.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from enum import Enum +from random import choice +from typing import Any + + +class FileType(Enum): + """An enumeration of common file types.""" + + UNKNOWN = 0 + "Unknown file type." + + # Text formats + TXT = 1 + "Plain text file." + DOC = 2 + "Microsoft Word document (.doc)" + DOCX = 3 + "Microsoft Word document (.docx)" + PDF = 4 + "Portable Document Format." + HTML = 5 + "HyperText Markup Language file." + XML = 6 + "Extensible Markup Language file." + CSV = 7 + "Comma-Separated Values file." + + # Spreadsheet formats + XLS = 8 + "Microsoft Excel file (.xls)" + XLSX = 9 + "Microsoft Excel file (.xlsx)" + + # Image formats + JPEG = 10 + "JPEG image file." + PNG = 11 + "PNG image file." + GIF = 12 + "GIF image file." + BMP = 13 + "Bitmap image file." + + # Audio formats + MP3 = 14 + "MP3 audio file." + WAV = 15 + "WAV audio file." + + # Video formats + MP4 = 16 + "MP4 video file." + AVI = 17 + "AVI video file." + MKV = 18 + "MKV video file." + FLV = 19 + "FLV video file." + + # Presentation formats + PPT = 20 + "Microsoft PowerPoint file (.ppt)" + PPTX = 21 + "Microsoft PowerPoint file (.pptx)" + + # Web formats + JS = 22 + "JavaScript file." + CSS = 23 + "Cascading Style Sheets file." + + # Programming languages + PY = 24 + "Python script file." + C = 25 + "C source code file." + CPP = 26 + "C++ source code file." + JAVA = 27 + "Java source code file." + + # Compressed file types + RAR = 28 + "RAR archive file." + ZIP = 29 + "ZIP archive file." + TAR = 30 + "TAR archive file." + GZ = 31 + "Gzip compressed file." + + # Database file types + DB = 32 + "Generic DB file. Used by sqlite3." + + @classmethod + def _missing_(cls, value: Any) -> FileType: + return cls.UNKNOWN + + @classmethod + def random(cls) -> FileType: + """ + Returns a random FileType. + + :return: A random FileType. + """ + return choice(list(FileType)) + + @property + def default_size(self) -> int: + """ + Get the default size of the FileType in bytes. + + Returns 0 if a default size does not exist. + """ + size = file_type_sizes_bytes[self] + return size if size else 0 + + +def get_file_type_from_extension(file_type_extension: str) -> FileType: + """ + Get a FileType from a file type extension. + + If a matching extension does not exist, FileType.UNKNOWN is returned. + + :param file_type_extension: A file type extension. + :return: A file type extension. + """ + try: + return FileType[file_type_extension.upper()] + except KeyError: + return FileType.UNKNOWN + + +file_type_sizes_bytes = { + FileType.UNKNOWN: 0, + FileType.TXT: 4096, + FileType.DOC: 51200, + FileType.DOCX: 30720, + FileType.PDF: 102400, + FileType.HTML: 15360, + FileType.XML: 10240, + FileType.CSV: 15360, + FileType.XLS: 102400, + FileType.XLSX: 25600, + FileType.JPEG: 102400, + FileType.PNG: 40960, + FileType.GIF: 30720, + FileType.BMP: 307200, + FileType.MP3: 5120000, + FileType.WAV: 25600000, + FileType.MP4: 25600000, + FileType.AVI: 51200000, + FileType.MKV: 51200000, + FileType.FLV: 15360000, + FileType.PPT: 204800, + FileType.PPTX: 102400, + FileType.JS: 10240, + FileType.CSS: 5120, + FileType.PY: 5120, + FileType.C: 5120, + FileType.CPP: 10240, + FileType.JAVA: 10240, + FileType.RAR: 1024000, + FileType.ZIP: 1024000, + FileType.TAR: 1024000, + FileType.GZ: 819200, + FileType.DB: 15360000, +} diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 1c7bbec7..f3afad12 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -6,7 +6,7 @@ from networkx import MultiGraph from prettytable import MARKDOWN, PrettyTable from primaite import getLogger -from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent +from primaite.simulator.core import Action, ActionManager, SimComponent from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import Router @@ -29,6 +29,8 @@ class Network(SimComponent): nodes: Dict[str, Node] = {} links: Dict[str, Link] = {} + _node_id_map: Dict[int, Node] = {} + _link_id_map: Dict[int, Node] = {} def __init__(self, **kwargs): """ @@ -47,7 +49,7 @@ class Network(SimComponent): am.add_action( "node", Action( - func = self._node_action_manager + func=self._node_action_manager # func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context), ), ) @@ -161,8 +163,8 @@ class Network(SimComponent): state = super().describe_state() state.update( { - "nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()}, - "links": {uuid: link.describe_state() for uuid, link in self.links.items()}, + "nodes": {i for i, node in self._node_id_map.items()}, + "links": {i: link.describe_state() for i, link in self._link_id_map.items()}, } ) return state @@ -179,10 +181,11 @@ class Network(SimComponent): _LOGGER.warning(f"Can't add node {node.uuid}. It is already in the network.") return self.nodes[node.uuid] = node + self._node_id_map[len(self.nodes)] = node node.parent = self self._nx_graph.add_node(node.hostname) _LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}") - self._node_action_manager.add_action(name = node.uuid, action = Action(func=node._action_manager)) + self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager)) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -210,9 +213,13 @@ class Network(SimComponent): _LOGGER.warning(f"Can't remove node {node.uuid}. It's not in the network.") return self.nodes.pop(node.uuid) + for i, _node in self._node_id_map.items(): + if node == _node: + self._node_id_map.pop(i) + break node.parent = None _LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}") - self._node_action_manager.remove_action(name = node.uuid) + self._node_action_manager.remove_action(name=node.uuid) def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None: """ @@ -237,9 +244,10 @@ class Network(SimComponent): return link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, **kwargs) self.links[link.uuid] = link + self._link_id_map[len(self.links)] = link self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) link.parent = self - _LOGGER.info(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") + _LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}") def remove_link(self, link: Link) -> None: """Disconnect a link from the network. @@ -250,6 +258,10 @@ class Network(SimComponent): link.endpoint_a.disconnect_link() link.endpoint_b.disconnect_link() self.links.pop(link.uuid) + for i, _link in self._link_id_map.items(): + if link == _link: + self._link_id_map.pop(i) + break link.parent = None _LOGGER.info(f"Removed link {link.uuid} from network {self.uuid}.") diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index e5f16323..24844cc3 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -4,12 +4,14 @@ import re import secrets from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.exceptions import NetworkError +from primaite.simulator import SIM_OUTPUT from primaite.simulator.core import Action, ActionManager, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem @@ -87,8 +89,6 @@ class NIC(SimComponent): "The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B" wake_on_lan: bool = False "Indicates if the NIC supports Wake-on-LAN functionality." - dns_servers: List[IPv4Address] = [] - "List of IP addresses of DNS servers used for name resolution." connected_node: Optional[Node] = None "The Node to which the NIC is connected." connected_link: Optional[Link] = None @@ -191,7 +191,7 @@ class NIC(SimComponent): if self.connected_node: self.connected_node.sys_log.info(f"NIC {self} disabled") else: - _LOGGER.info(f"NIC {self} disabled") + _LOGGER.debug(f"NIC {self} disabled") if self.connected_link: self.connected_link.endpoint_down() @@ -213,7 +213,7 @@ class NIC(SimComponent): # TODO: Inform the Node that a link has been connected self.connected_link = link self.enable() - _LOGGER.info(f"NIC {self} connected to Link {link}") + _LOGGER.debug(f"NIC {self} connected to Link {link}") def disconnect_link(self): """Disconnect the NIC from the connected Link.""" @@ -356,7 +356,7 @@ class SwitchPort(SimComponent): if self.connected_node: self.connected_node.sys_log.info(f"SwitchPort {self} disabled") else: - _LOGGER.info(f"SwitchPort {self} disabled") + _LOGGER.debug(f"SwitchPort {self} disabled") if self.connected_link: self.connected_link.endpoint_down() @@ -376,7 +376,7 @@ class SwitchPort(SimComponent): # TODO: Inform the Switch that a link has been connected self.connected_link = link - _LOGGER.info(f"SwitchPort {self} connected to Link {link}") + _LOGGER.debug(f"SwitchPort {self} connected to Link {link}") self.enable() def disconnect_link(self): @@ -411,7 +411,8 @@ class SwitchPort(SimComponent): if self.enabled: frame.decrement_ttl() self.pcap.capture(frame) - self.connected_node.forward_frame(frame=frame, incoming_port=self) + connected_node: Node = self.connected_node + connected_node.forward_frame(frame=frame, incoming_port=self) return True return False @@ -482,13 +483,13 @@ class Link(SimComponent): def endpoint_up(self): """Let the Link know and endpoint has been brought up.""" if self.is_up: - _LOGGER.info(f"Link {self} up") + _LOGGER.debug(f"Link {self} up") def endpoint_down(self): """Let the Link know and endpoint has been brought down.""" if not self.is_up: self.current_load = 0.0 - _LOGGER.info(f"Link {self} down") + _LOGGER.debug(f"Link {self} down") @property def is_up(self) -> bool: @@ -515,7 +516,7 @@ class Link(SimComponent): """ can_transmit = self._can_transmit(frame) if not can_transmit: - _LOGGER.info(f"Cannot transmit frame as {self} is at capacity") + _LOGGER.debug(f"Cannot transmit frame as {self} is at capacity") return False receiver = self.endpoint_a @@ -527,7 +528,7 @@ class Link(SimComponent): # Frame transmitted successfully # Load the frame size on the link self.current_load += frame_size - _LOGGER.info( + _LOGGER.debug( f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits " f"({self.current_load_percent})" ) @@ -886,6 +887,8 @@ class Node(SimComponent): "The NICs on the node." ethernet_port: Dict[int, NIC] = {} "The NICs on the node by port id." + dns_server: Optional[IPv4Address] = None + "List of IP addresses of DNS servers used for name resolution." accounts: Dict[str, Account] = {} "All accounts on the node." @@ -897,6 +900,8 @@ class Node(SimComponent): "All processes on the node." file_system: FileSystem "The nodes file system." + root: Path + "Root directory for simulation output." sys_log: SysLog arp: ARPCache icmp: ICMP @@ -924,14 +929,20 @@ class Node(SimComponent): kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) + if not kwargs.get("root"): + kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] + if not kwargs.get("file_system"): + kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): kwargs["software_manager"] = SoftwareManager( - sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager") + sys_log=kwargs.get("sys_log"), + session_manager=kwargs.get("session_manager"), + file_system=kwargs.get("file_system"), + dns_server=kwargs.get("dns_server"), ) - if not kwargs.get("file_system"): - kwargs["file_system"] = FileSystem() super().__init__(**kwargs) self.arp.nics = self.nics + self.session_manager.software_manager = self.software_manager def _init_action_manager(self) -> ActionManager: # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will @@ -975,7 +986,25 @@ class Node(SimComponent): ) return state - def show(self, markdown: bool = False): + def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"): + """A multi-use .show function that accepts either NIC or OPEN_PORTS.""" + if component == "NIC": + self._show_nic(markdown) + elif component == "OPEN_PORTS": + self._show_open_ports(markdown) + + def _show_open_ports(self, markdown: bool = False): + """Prints a table of the open ports on the Node.""" + table = PrettyTable(["Port", "Name"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.hostname} Open Ports" + for port in self.software_manager.get_open_ports(): + table.add_row([port.value, port.name]) + print(table) + + def _show_nic(self, markdown: bool = False): """Prints a table of the NICs on the Node.""" table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"]) if markdown: @@ -1066,29 +1095,30 @@ class Node(SimComponent): :param pings: The number of pings to attempt, default is 4. :return: True if the ping is successful, otherwise False. """ - if not isinstance(target_ip_address, IPv4Address): - target_ip_address = IPv4Address(target_ip_address) - if target_ip_address.is_loopback: - self.sys_log.info("Pinging loopback address") - return any(nic.enabled for nic in self.nics.values()) if self.operating_state == NodeOperatingState.ON: - self.sys_log.info(f"Pinging {target_ip_address}:") - sequence, identifier = 0, None - while sequence < pings: - sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings) - request_replies = self.icmp.request_replies.get(identifier) - passed = request_replies == pings - if request_replies: - self.icmp.request_replies.pop(identifier) - else: - request_replies = 0 - self.sys_log.info( - f"Ping statistics for {target_ip_address}: " - f"Packets: Sent = {pings}, " - f"Received = {request_replies}, " - f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)" - ) - return passed + if not isinstance(target_ip_address, IPv4Address): + target_ip_address = IPv4Address(target_ip_address) + if target_ip_address.is_loopback: + self.sys_log.info("Pinging loopback address") + return any(nic.enabled for nic in self.nics.values()) + if self.operating_state == NodeOperatingState.ON: + self.sys_log.info(f"Pinging {target_ip_address}:") + sequence, identifier = 0, None + while sequence < pings: + sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings) + request_replies = self.icmp.request_replies.get(identifier) + passed = request_replies == pings + if request_replies: + self.icmp.request_replies.pop(identifier) + else: + request_replies = 0 + self.sys_log.info( + f"Ping statistics for {target_ip_address}: " + f"Packets: Sent = {pings}, " + f"Received = {request_replies}, " + f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)" + ) + return passed return False def send_frame(self, frame: Frame): @@ -1097,7 +1127,8 @@ class Node(SimComponent): :param frame: The Frame to be sent. """ - nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address) + if self.operating_state == NodeOperatingState.ON: + nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address) nic.send_frame(frame) def receive_frame(self, frame: Frame, from_nic: NIC): @@ -1110,18 +1141,27 @@ class Node(SimComponent): :param frame: The Frame being received. :param from_nic: The NIC that received the frame. """ - if frame.ip: - if frame.ip.src_ip_address in self.arp: - self.arp.add_arp_cache_entry( - ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic - ) - if frame.ip.protocol == IPProtocol.TCP: - if frame.tcp.src_port == Port.ARP: - self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) - elif frame.ip.protocol == IPProtocol.UDP: - pass - elif frame.ip.protocol == IPProtocol.ICMP: - self.icmp.process_icmp(frame=frame, from_nic=from_nic) + if self.operating_state == NodeOperatingState.ON: + if frame.ip: + if frame.ip.src_ip_address in self.arp: + self.arp.add_arp_cache_entry( + ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic + ) + if frame.ip.protocol == IPProtocol.ICMP: + self.icmp.process_icmp(frame=frame, from_nic=from_nic) + return + # Check if the destination port is open on the Node + if frame.tcp.dst_port in self.software_manager.get_open_ports(): + # accept thr frame as the port is open + if frame.tcp.src_port == Port.ARP: + self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp) + else: + self.session_manager.receive_frame(frame) + else: + # denied as port closed + self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}") + # TODO: do we need to do anything more here? + pass def install_service(self, service: Service) -> None: """ diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index 6a50fe3f..78d2e68f 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -1,3 +1,5 @@ +from ipaddress import IPv4Address + from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import NIC from primaite.simulator.network.hardware.nodes.computer import Computer @@ -6,6 +8,11 @@ from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database_service import DatabaseService +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot def client_server_routed() -> Network: @@ -121,16 +128,33 @@ def arcd_uc2_network() -> Network: # Client 1 client_1 = Computer( - hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1" + hostname="client_1", + ip_address="192.168.10.21", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + dns_server=IPv4Address("192.168.1.10"), ) client_1.power_on() + client_1.software_manager.install(DNSClient) + client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa + client_1_dns_client_service.start() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) + client_1.software_manager.install(DataManipulationBot) + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;") # Client 2 client_2 = Computer( - hostname="client_2", ip_address="192.168.10.22", subnet_mask="255.255.255.0", default_gateway="192.168.10.1" + hostname="client_2", + ip_address="192.168.10.22", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + dns_server=IPv4Address("192.168.1.10"), ) client_2.power_on() + client_2.software_manager.install(DNSClient) + client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa + client_2_dns_client_service.start() network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) # Domain Controller @@ -141,14 +165,9 @@ def arcd_uc2_network() -> Network: default_gateway="192.168.1.1", ) domain_controller.power_on() - network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) + domain_controller.software_manager.install(DNSServer) - # Web Server - web_server = Server( - hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" - ) - web_server.power_on() - network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) + network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) # Database Server database_server = Server( @@ -156,13 +175,73 @@ def arcd_uc2_network() -> Network: ip_address="192.168.1.14", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) database_server.power_on() network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3]) + ddl = """ + CREATE TABLE IF NOT EXISTS user ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(50) NOT NULL, + email VARCHAR(50) NOT NULL, + age INT, + city VARCHAR(50), + occupation VARCHAR(50) + );""" + + user_insert_statements = [ + "INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa + "INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa + ] + database_server.software_manager.install(DatabaseService) + database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa + database_service.start() + database_service._process_sql(ddl, None) # noqa + for insert_statement in user_insert_statements: + database_service._process_sql(insert_statement, None) # noqa + + # Web Server + web_server = Server( + hostname="web_server", + ip_address="192.168.1.12", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), + ) + web_server.power_on() + web_server.software_manager.install(DatabaseClient) + + database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + database_client.configure(server_ip_address=IPv4Address("192.168.1.14")) + network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) + database_client.run() + database_client.connect() + + # register the web_server to a domain + dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa + dns_server_service.start() + dns_server_service.dns_register("arcd.com", web_server.ip_address) + # Backup Server backup_server = Server( - hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" + hostname="backup_server", + ip_address="192.168.1.16", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) backup_server.power_on() network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4]) @@ -173,6 +252,7 @@ def arcd_uc2_network() -> Network: ip_address="192.168.1.110", subnet_mask="255.255.255.0", default_gateway="192.168.1.1", + dns_server=IPv4Address("192.168.1.10"), ) security_suite.power_on() network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7]) @@ -183,4 +263,12 @@ def arcd_uc2_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + # Allow PostgreSQL requests + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + ) + + # Allow DNS requests + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + return network diff --git a/src/primaite/simulator/network/protocols/dns.py b/src/primaite/simulator/network/protocols/dns.py new file mode 100644 index 00000000..41bf5e0c --- /dev/null +++ b/src/primaite/simulator/network/protocols/dns.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Optional + +from pydantic import BaseModel + + +class DNSRequest(BaseModel): + """Represents a DNS Request packet of a network frame. + + :param domain_name_request: Domain Name Request for IP address. + """ + + domain_name_request: str + "Domain Name Request for IP address." + + +class DNSReply(BaseModel): + """Represents a DNS Reply packet of a network frame. + + :param domain_name_ip_address: IP Address of the Domain Name requested. + """ + + domain_name_ip_address: Optional[IPv4Address] = None + "IP Address of the Domain Name requested." + + +class DNSPacket(BaseModel): + """ + Represents the DNS layer of a network frame. + + :param dns_request: DNS Request packet sent by DNS Client. + :param dns_reply: DNS Reply packet generated by DNS Server. + + :Example: + + >>> dns_request = DNSPacket( + ... domain_name_request=DNSRequest(domain_name_request="www.google.co.uk"), + ... dns_reply=None + ... ) + >>> dns_response = DNSPacket( + ... dns_request=DNSRequest(domain_name_request="www.google.co.uk"), + ... dns_reply=DNSReply(domain_name_ip_address=IPv4Address("142.250.179.227")) + ... ) + """ + + dns_request: DNSRequest + "DNS Request packet sent by DNS Client." + dns_reply: Optional[DNSReply] = None + "DNS Reply packet generated by DNS Server." + + def generate_reply(self, domain_ip_address: IPv4Address) -> DNSPacket: + """Generate a new DNSPacket to be sent as a response with a DNS Reply packet which contains the IP address. + + :param domain_ip_address: The IP address that was being sought after from the original target domain name. + :return: A new instance of DNSPacket. + """ + self.dns_reply = DNSReply(domain_name_ip_address=domain_ip_address) + + return self diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index b95b4a74..d4318baf 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -59,6 +59,8 @@ class Port(Enum): "Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications." HTTPS_ALT = 8443 "Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic." + POSTGRES_SERVER = 5432 + "Postgres SQL Server." class UDPHeader(BaseModel): diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 6a07f00f..30efd5b7 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -23,9 +23,9 @@ class Application(IOSoftware): Applications are user-facing programs that may perform input/output operations. """ - operating_state: ApplicationOperatingState + operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED "The current operating state of the Application." - execution_control_status: str + execution_control_status: str = "manual" "Control status of the application's execution. It could be 'manual' or 'automatic'." num_executions: int = 0 "The number of times the application has been executed. Default is 0." @@ -53,6 +53,25 @@ class Application(IOSoftware): ) return state + def run(self) -> None: + """Open the Application.""" + if self.operating_state == ApplicationOperatingState.CLOSED: + self.sys_log.info(f"Running Application {self.name}") + self.operating_state = ApplicationOperatingState.RUNNING + + def close(self) -> None: + """Close the Application.""" + if self.operating_state == ApplicationOperatingState.RUNNING: + self.sys_log.info(f"Closed Application{self.name}") + self.operating_state = ApplicationOperatingState.CLOSED + + def install(self) -> None: + """Install Application.""" + super().install() + if self.operating_state == ApplicationOperatingState.CLOSED: + self.sys_log.info(f"Installing Application {self.name}") + self.operating_state = ApplicationOperatingState.INSTALLING + def reset_component_for_episode(self, episode: int): """ Resets the Application component for a new episode. diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py new file mode 100644 index 00000000..9d59a2f4 --- /dev/null +++ b/src/primaite/simulator/system/applications/database_client.py @@ -0,0 +1,157 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional +from uuid import uuid4 + +from prettytable import PrettyTable + +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application, ApplicationOperatingState +from primaite.simulator.system.core.software_manager import SoftwareManager + + +class DatabaseClient(Application): + """ + A DatabaseClient application. + + Extends the Application class to provide functionality for connecting, querying, and disconnecting from a + Database Service. It mainly operates over TCP protocol. + + :ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None. + """ + + server_ip_address: Optional[IPv4Address] = None + server_password: Optional[str] = None + connected: bool = False + _query_success_tracker: Dict[str, bool] = {} + + def __init__(self, **kwargs): + kwargs["name"] = "DatabaseClient" + kwargs["port"] = Port.POSTGRES_SERVER + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + """ + Describes the current state of the ACLRule. + + :return: A dictionary representing the current state. + """ + pass + return super().describe_state() + + def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): + """ + Configure the DatabaseClient to communicate with a DatabaseService. + + :param server_ip_address: The IP address of the Node the DatabaseService is on. + :param server_password: The password on the DatabaseService. + """ + self.server_ip_address = server_ip_address + self.server_password = server_password + self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {server_password=}.") + + def connect(self) -> bool: + """Connect to a Database Service.""" + if not self.connected and self.operating_state.RUNNING: + return self._connect(self.server_ip_address, self.server_password) + return False + + def _connect( + self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False + ) -> bool: + if is_reattempt: + if self.connected: + self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised") + self.server_ip_address = server_ip_address + return self.connected + else: + self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined") + return False + payload = {"type": "connect_request", "password": password} + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=server_ip_address, dest_port=self.port + ) + return self._connect(server_ip_address, password, True) + + def disconnect(self): + """Disconnect from the Database Service.""" + if self.connected and self.operating_state.RUNNING: + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port + ) + + self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}") + self.server_ip_address = None + self.connected = False + + def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool: + if is_reattempt: + success = self._query_success_tracker.get(query_id) + if success: + return True + return False + else: + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "sql", "sql": sql, "uuid": query_id}, + dest_ip_address=self.server_ip_address, + dest_port=self.port, + ) + return self._query(sql=sql, query_id=query_id, is_reattempt=True) + + def run(self) -> None: + """Run the DatabaseClient.""" + super().run() + self.operating_state = ApplicationOperatingState.RUNNING + self.connect() + + def query(self, sql: str) -> bool: + """ + Send a query to the Database Service. + + :param sql: The SQL query. + :return: True if the query was successful, otherwise False. + """ + if self.connected and self.operating_state.RUNNING: + query_id = str(uuid4()) + + # Initialise the tracker of this ID to False + self._query_success_tracker[query_id] = False + return self._query(sql=sql, query_id=query_id) + + def _print_data(self, data: Dict): + """ + Display the contents of the Folder in tabular format. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + if data: + table = PrettyTable(list(data.values())[0]) + + table.align = "l" + table.title = f"{self.sys_log.hostname} Database Client" + for row in data.values(): + table.add_row(row.values()) + print(table) + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Receive a payload from the Software Manager. + + :param payload: A payload to receive. + :param session_id: The session id the payload relates to. + :return: True. + """ + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_response": + self.connected = payload["response"] == True + elif payload["type"] == "sql": + query_id = payload.get("uuid") + status_code = payload.get("status_code") + self._query_success_tracker[query_id] = status_code == 200 + if self._query_success_tracker[query_id]: + self._print_data(payload["data"]) + return True diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py new file mode 100644 index 00000000..78d196b7 --- /dev/null +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -0,0 +1,54 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional + +from primaite.simulator.system.applications.application import Application + + +class WebBrowser(Application): + """ + Represents a web browser in the simulation environment. + + The application requests and loads web pages using its domain name and requesting IP addresses using DNS. + """ + + domain_name: str + "The domain name of the webpage." + domain_name_ip_address: Optional[IPv4Address] + "The IP address of the domain name for the webpage." + history: Dict[str] + "A dict that stores all of the previous domain names." + + def reset_component_for_episode(self, episode: int): + """ + Resets the Application component for a new episode. + + This method ensures the Application is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + self.domain_name = "" + self.domain_name_ip_address = None + self.history = {} + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Sends a payload to the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to send. + :return: True if successful, False otherwise. + """ + pass + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to receive. + :return: True if successful, False otherwise. + """ + pass diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index c985af1f..2e5ed008 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -1,8 +1,9 @@ +import json import logging from pathlib import Path -from typing import Optional +from typing import Any, Dict, List, Optional -from primaite.simulator import TEMP_SIM_OUTPUT +from primaite.simulator import SIM_OUTPUT class _JSONFilter(logging.Filter): @@ -51,6 +52,18 @@ class PacketCapture: self.logger.addFilter(_JSONFilter()) + def read(self) -> List[Dict[str, Any]]: + """ + Read packet capture logs and return them as a list of dictionaries. + + :return: List of frames captured, represented as dictionaries. + """ + frames = [] + with open(self._get_log_path(), "r") as file: + while line := file.readline(): + frames.append(json.loads(line.rstrip())) + return frames + @property def _logger_name(self) -> str: """Get PCAP the logger name.""" @@ -62,7 +75,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = TEMP_SIM_OUTPUT / self.hostname + root = SIM_OUTPUT / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 7f3d22c5..95ece9f9 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -1,12 +1,14 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union + +from prettytable import MARKDOWN, PrettyTable from primaite.simulator.core import SimComponent -from primaite.simulator.network.transmission.data_link_layer import Frame -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame +from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader if TYPE_CHECKING: from primaite.simulator.network.hardware.base import ARPCache @@ -30,27 +32,23 @@ class Session(SimComponent): """ protocol: IPProtocol - src_ip_address: IPv4Address - dst_ip_address: IPv4Address + with_ip_address: IPv4Address src_port: Optional[Port] dst_port: Optional[Port] connected: bool = False @classmethod - def from_session_key( - cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]] - ) -> Session: + def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session: """ Create a Session instance from a session key tuple. :param session_key: Tuple containing the session details. :return: A Session instance. """ - protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key + protocol, with_ip_address, src_port, dst_port = session_key return Session( protocol=protocol, - src_ip_address=src_ip_address, - dst_ip_address=dst_ip_address, + with_ip_address=with_ip_address, src_port=src_port, dst_port=dst_port, ) @@ -97,8 +95,8 @@ class SessionManager: @staticmethod def _get_session_key( - frame: Frame, from_source: bool = True - ) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]: + frame: Frame, inbound_frame: bool = True + ) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]: """ Extracts the session key from the given frame. @@ -110,32 +108,39 @@ class SessionManager: - Optional[Port]: The destination port number (if applicable). :param frame: The network frame from which to extract the session key. - :param from_source: A flag to indicate if the key should be extracted from the source or destination. :return: A tuple containing the session key. """ protocol = frame.ip.protocol - src_ip_address = frame.ip.src_ip_address - dst_ip_address = frame.ip.dst_ip_address + with_ip_address = frame.ip.src_ip_address if protocol == IPProtocol.TCP: - if from_source: + if inbound_frame: src_port = frame.tcp.src_port dst_port = frame.tcp.dst_port else: dst_port = frame.tcp.src_port src_port = frame.tcp.dst_port + with_ip_address = frame.ip.dst_ip_address elif protocol == IPProtocol.UDP: - if from_source: + if inbound_frame: src_port = frame.udp.src_port dst_port = frame.udp.dst_port else: dst_port = frame.udp.src_port src_port = frame.udp.dst_port + with_ip_address = frame.ip.dst_ip_address else: src_port = None dst_port = None - return protocol, src_ip_address, dst_ip_address, src_port, dst_port + return protocol, with_ip_address, src_port, dst_port - def receive_payload_from_software_manager(self, payload: Any, session_id: Optional[int] = None): + def receive_payload_from_software_manager( + self, + payload: Any, + dst_ip_address: Optional[IPv4Address] = None, + dst_port: Optional[Port] = None, + session_id: Optional[str] = None, + is_reattempt: bool = False, + ) -> Union[Any, None]: """ Receive a payload from the SoftwareManager. @@ -144,46 +149,87 @@ class SessionManager: :param payload: The payload to be sent. :param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created. """ - # TODO: Implement session creation and + if session_id: + session = self.sessions_by_uuid[session_id] + dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address + dst_port = self.sessions_by_uuid[session_id].dst_port - self.send_payload_to_nic(payload, session_id) + dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address) - def send_payload_to_software_manager(self, payload: Any, session_id: int): + if dst_mac_address: + outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address) + else: + if not is_reattempt: + self.arp_cache.send_arp_request(dst_ip_address) + return self.receive_payload_from_software_manager( + payload=payload, + dst_ip_address=dst_ip_address, + dst_port=dst_port, + session_id=session_id, + is_reattempt=True, + ) + else: + return + + frame = Frame( + ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address), + ip=IPPacket( + src_ip_address=outbound_nic.ip_address, + dst_ip_address=dst_ip_address, + ), + tcp=TCPHeader( + src_port=dst_port, + dst_port=dst_port, + ), + payload=payload, + ) + + if not session_id: + session_key = self._get_session_key(frame, inbound_frame=False) + session = self.sessions_by_key.get(session_key) + if not session: + # Create new session + session = Session.from_session_key(session_key) + self.sessions_by_key[session_key] = session + self.sessions_by_uuid[session.uuid] = session + + outbound_nic.send_frame(frame) + + def receive_frame(self, frame: Frame): """ - Send a payload to the software manager. - - :param payload: The payload to be sent. - :param session_id: The Session ID the payload originates from. - """ - self.software_manager.receive_payload_from_session_manger() - - def send_payload_to_nic(self, payload: Any, session_id: int): - """ - Send a payload across the Network. - - Takes a payload and a session_id. Builds a Frame and sends it across the network via a NIC. - - :param payload: The payload to be sent. - :param session_id: The Session ID the payload originates from - """ - # TODO: Implement frame construction and sent to NIC. - pass - - def receive_payload_from_nic(self, frame: Frame): - """ - Receive a Frame from the NIC. + Receive a Frame. Extract the session key using the _get_session_key method, and forward the payload to the appropriate session. If the session does not exist, a new one is created. :param frame: The frame being received. """ - session_key = self._get_session_key(frame) - session = self.sessions_by_key.get(session_key) + session_key = self._get_session_key(frame, inbound_frame=True) + session: Session = self.sessions_by_key.get(session_key) if not session: # Create new session session = Session.from_session_key(session_key) self.sessions_by_key[session_key] = session self.sessions_by_uuid[session.uuid] = session - self.software_manager.receive_payload_from_session_manger(payload=frame, session=session) - # TODO: Implement the frame deconstruction and send to SoftwareManager. + self.software_manager.receive_payload_from_session_manager( + payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid + ) + + def show(self, markdown: bool = False): + """ + Print tables describing the SessionManager. + + Generate and print PrettyTable instances that show details about + session's destination IP Address, destination Ports and the protocol to use. + Output can be in Markdown format. + + :param markdown: Use Markdown style in table output. Defaults to False. + """ + table = PrettyTable(["Destination IP", "Port", "Protocol"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} Session Manager" + for session in self.sessions_by_key.values(): + table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name]) + print(table) diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 411fb6e9..99445bf8 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -1,99 +1,162 @@ -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union +from ipaddress import IPv4Address +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from prettytable import MARKDOWN, PrettyTable + +from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application -from primaite.simulator.system.core.session_manager import Session +from primaite.simulator.system.applications.application import Application, ApplicationOperatingState from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.services.service import Service -from primaite.simulator.system.software import SoftwareType +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import IOSoftware if TYPE_CHECKING: from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog +from typing import Type, TypeVar + +IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) + class SoftwareManager: """A class that manages all running Services and Applications on a Node and facilitates their communication.""" - def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"): + def __init__( + self, + session_manager: "SessionManager", + sys_log: SysLog, + file_system: FileSystem, + dns_server: Optional[IPv4Address], + ): """ Initialize a new instance of SoftwareManager. :param session_manager: The session manager handling network communications. """ self.session_manager = session_manager - self.services: Dict[str, Service] = {} - self.applications: Dict[str, Application] = {} + self.software: Dict[str, Union[Service, Application]] = {} + self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {} self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log + self.file_system: FileSystem = file_system + self.dns_server: Optional[IPv4Address] = dns_server - def add_service(self, name: str, service: Service, port: Port, protocol: IPProtocol): + def get_open_ports(self) -> List[Port]: """ - Add a Service to the manager. + Get a list of open ports. - :param name: The name of the service. - :param service: The service instance. - :param port: The port used by the service. - :param protocol: The network protocol used by the service. + :return: A list of all open ports on the Node. """ - service.software_manager = self - self.services[name] = service - self.port_protocol_mapping[(port, protocol)] = service + open_ports = [Port.ARP] + for software in self.port_protocol_mapping.values(): + if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}: + open_ports.append(software.port) + open_ports.sort(key=lambda port: port.value) + return open_ports - def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol): + def install(self, software_class: Type[IOSoftwareClass]): """ - Add an Application to the manager. + Install an Application or Service. - :param name: The name of the application. - :param application: The application instance. - :param port: The port used by the application. - :param protocol: The network protocol used by the application. + :param software_class: The software class. """ - application.software_manager = self - self.applications[name] = application - self.port_protocol_mapping[(port, protocol)] = application + if software_class in self._software_class_to_name_map: + self.sys_log.info(f"Cannot install {software_class} as it is already installed") + return + software = software_class( + software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server + ) + if isinstance(software, Application): + software.install() + software.software_manager = self + self.software[software.name] = software + self.port_protocol_mapping[(software.port, software.protocol)] = software + self.sys_log.info(f"Installed {software.name}") + if isinstance(software, Application): + software.operating_state = ApplicationOperatingState.CLOSED - def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any): + def uninstall(self, software_name: str): + """ + Uninstall an Application or Service. + + :param software_name: The software name. + """ + if software_name in self.software: + software = self.software.pop(software_name) # noqa + del software + self.sys_log.info(f"Deleted {software_name}") + return + self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") + + def send_internal_payload(self, target_software: str, payload: Any): """ Send a payload to a specific service or application. :param target_software: The name of the target service or application. - :param target_software_type: The type of software (Service, Application, Process). :param payload: The data to be sent. - :param receiver_type: The type of the target, either 'service' or 'application'. """ - if target_software_type is SoftwareType.SERVICE: - receiver = self.services.get(target_software) - elif target_software_type is SoftwareType.APPLICATION: - receiver = self.applications.get(target_software) - else: - raise ValueError(f"Invalid receiver type {target_software_type}") + receiver = self.software.get(target_software) if receiver: receiver.receive_payload(payload) else: - raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}") + self.sys_log.error(f"No Service of Application found with the name {target_software}") - def send_payload_to_session_manger(self, payload: Any, session_id: Optional[int] = None): + def send_payload_to_session_manager( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + dest_port: Optional[Port] = None, + session_id: Optional[str] = None, + ): """ Send a payload to the SessionManager. :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. :param session_id: The Session ID the payload is to originate from. Optional. """ - self.session_manager.receive_payload_from_software_manager(payload, session_id) + self.session_manager.receive_payload_from_software_manager( + payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id + ) - def receive_payload_from_session_manger(self, payload: Any, session: Session): + def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str): """ Receive a payload from the SessionManager and forward it to the corresponding service or application. :param payload: The payload being received. :param session: The transport session the payload originates from. """ - # receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) - # if receiver: - # receiver.receive_payload(None, payload) - # else: - # raise ValueError(f"No service or application found for port {port} and protocol {protocol}") + receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None) + if receiver: + receiver.receive(payload=payload, session_id=session_id) + else: + self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}") pass + + def show(self, markdown: bool = False): + """ + Prints a table of the SwitchPorts on the Switch. + + :param markdown: If True, outputs the table in markdown format. Default is False. + """ + table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} Software Manager" + for software in self.port_protocol_mapping.values(): + software_type = "Service" if isinstance(software, Service) else "Application" + table.add_row( + [ + software.name, + software_type, + software.operating_state.name, + software.health_state_actual.name, + software.port.value, + ] + ) + print(table) diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index e07c28aa..791e0be8 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -3,7 +3,7 @@ from pathlib import Path from prettytable import MARKDOWN, PrettyTable -from primaite.simulator import TEMP_SIM_OUTPUT +from primaite.simulator import SIM_OUTPUT class _NotJSONFilter(logging.Filter): @@ -81,7 +81,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = TEMP_SIM_OUTPUT / self.hostname + root = SIM_OUTPUT / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log" diff --git a/src/primaite/simulator/system/services/database.py b/src/primaite/simulator/system/services/database.py deleted file mode 100644 index 23b856f7..00000000 --- a/src/primaite/simulator/system/services/database.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Dict - -from primaite.simulator.file_system.file_system_file_type import FileSystemFileType -from primaite.simulator.network.hardware.base import Node -from primaite.simulator.system.services.service import Service - - -class DatabaseService(Service): - """Service loosely modelled on Microsoft SQL Server.""" - - def describe_state(self) -> Dict: - """ - Produce a dictionary describing the current state of this object. - - Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. - - :return: Current state of this object and child objects. - :rtype: Dict - """ - return super().describe_state() - - def uninstall(self) -> None: - """ - Undo installation procedure. - - This method deletes files created when installing the database, and the database folder if it is empty. - """ - super().uninstall() - node: Node = self.parent - node.file_system.delete_file(self.primary_store) - node.file_system.delete_file(self.transaction_log) - if self.secondary_store: - node.file_system.delete_file(self.secondary_store) - if len(self.folder.files) == 0: - node.file_system.delete_folder(self.folder) - - def install(self) -> None: - """Perform first time install on a node, creating necessary files.""" - super().install() - assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node" - self._setup_files() - - def _setup_files( - self, - db_size: int = 1000, - use_secondary_db_file: bool = False, - secondary_db_size: int = 300, - folder_name: str = "database", - ): - """Set up files that are required by the database on the parent host. - - :param db_size: Initial file size of the main database file, defaults to 1000 - :type db_size: int, optional - :param use_secondary_db_file: Whether to use a secondary database file, defaults to False - :type use_secondary_db_file: bool, optional - :param secondary_db_size: Size of the secondary db file, defaults to None - :type secondary_db_size: int, optional - :param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database" - :type folder_name: str, optional - """ - # note that this parent.file_system.create_folder call in the future will be authenticated by using permissions - # handler. This permission will be granted based on service account given to the database service. - self.parent: Node - self.folder = self.parent.file_system.create_folder(folder_name) - self.primary_store = self.parent.file_system.create_file( - "db_primary_store", db_size, FileSystemFileType.MDF, folder=self.folder - ) - self.transaction_log = self.parent.file_system.create_file( - "db_transaction_log", "1", FileSystemFileType.LDF, folder=self.folder - ) - if use_secondary_db_file: - self.secondary_store = self.parent.file_system.create_file( - "db_secondary_store", secondary_db_size, FileSystemFileType.NDF, folder=self.folder - ) - else: - self.secondary_store = None diff --git a/src/primaite/simulator/system/services/database_service.py b/src/primaite/simulator/system/services/database_service.py new file mode 100644 index 00000000..62120fc7 --- /dev/null +++ b/src/primaite/simulator/system/services/database_service.py @@ -0,0 +1,155 @@ +import sqlite3 +from datetime import datetime +from sqlite3 import OperationalError +from typing import Any, Dict, List, Optional, Union + +from prettytable import MARKDOWN, PrettyTable + +from primaite.simulator.file_system.file_system import File +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState + + +class DatabaseService(Service): + """ + A class for simulating a generic SQL Server service. + + This class inherits from the `Service` class and provides methods to manage and query a SQLite database. + """ + + password: Optional[str] = None + connections: Dict[str, datetime] = {} + + def __init__(self, **kwargs): + kwargs["name"] = "DatabaseService" + kwargs["port"] = Port.POSTGRES_SERVER + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + self._db_file: File + self._create_db_file() + self._conn = sqlite3.connect(self._db_file.sim_path) + self._cursor = self._conn.cursor() + + def tables(self) -> List[str]: + """ + Get a list of table names present in the database. + + :return: List of table names. + """ + sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';" + results = self._process_sql(sql) + return [row[0] for row in results["data"]] + + def show(self, markdown: bool = False): + """ + Prints a list of table names in the database using PrettyTable. + + :param markdown: Whether to output the table in Markdown format. + """ + table = PrettyTable(["Table"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.file_system.sys_log.hostname} Database" + for row in self.tables(): + table.add_row([row]) + print(table) + + def _create_db_file(self): + """Creates the Simulation File and sqlite file in the file system.""" + self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True) + self.folder = self._db_file.folder + + def _process_connect( + self, session_id: str, password: Optional[str] = None + ) -> Dict[str, Union[int, Dict[str, bool]]]: + status_code = 500 # Default internal server error + if self.operating_state == ServiceOperatingState.RUNNING: + status_code = 503 # service unavailable + if self.health_state_actual == SoftwareHealthState.GOOD: + if self.password == password: + status_code = 200 # ok + self.connections[session_id] = datetime.now() + self.sys_log.info(f"Connect request for {session_id=} authorised") + else: + status_code = 401 # Unauthorised + self.sys_log.info(f"Connect request for {session_id=} declined") + else: + status_code = 404 # service not found + return {"status_code": status_code, "type": "connect_response", "response": status_code == 200} + + def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]: + """ + Executes the given SQL query and returns the result. + + :param query: The SQL query to be executed. + :return: Dictionary containing status code and data fetched. + """ + self.sys_log.info(f"{self.name}: Running {query}") + try: + self._cursor.execute(query) + self._conn.commit() + except OperationalError: + # Handle the case where the table does not exist. + self.sys_log.error(f"{self.name}: Error, query failed") + return {"status_code": 404, "data": {}} + data = [] + description = self._cursor.description + if description: + headers = [] + for header in description: + headers.append(header[0]) + data = self._cursor.fetchall() + if data and headers: + data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data} + return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id} + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + return super().describe_state() + + def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Processes the incoming SQL payload and sends the result back. + + :param payload: The SQL query to be executed. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + result = {"status_code": 500, "data": []} + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "connect_request": + result = self._process_connect(session_id=session_id, password=payload.get("password")) + elif payload["type"] == "disconnect": + if session_id in self.connections: + self.connections.pop(session_id) + elif payload["type"] == "sql": + if session_id in self.connections: + result = self._process_sql(query=payload["sql"], query_id=payload["uuid"]) + else: + result = {"status_code": 401, "type": "sql"} + self.send(payload=result, session_id=session_id) + return True + + def send(self, payload: Any, session_id: str, **kwargs) -> bool: + """ + Send a SQL response back down to the SessionManager. + + :param payload: The SQL query results. + :param session_id: The session identifier. + :return: True if the Status Code is 200, otherwise False. + """ + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + + return payload["status_code"] == 200 diff --git a/src/primaite/simulator/system/services/dns_client.py b/src/primaite/simulator/system/services/dns_client.py new file mode 100644 index 00000000..cf5278af --- /dev/null +++ b/src/primaite/simulator/system/services/dns_client.py @@ -0,0 +1,154 @@ +from ipaddress import IPv4Address +from typing import Dict, Optional + +from primaite import getLogger +from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.simulator.system.services.service import Service + +_LOGGER = getLogger(__name__) + + +class DNSClient(Service): + """Represents a DNS Client as a Service.""" + + dns_cache: Dict[str, IPv4Address] = {} + "A dict of known mappings between domain/URLs names and IPv4 addresses." + dns_server: Optional[IPv4Address] = None + "The DNS Server the client sends requests to." + + def __init__(self, **kwargs): + kwargs["name"] = "DNSClient" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + # TCP for now + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + """ + Describes the current state of the software. + + The specifics of the software's state, including its health, criticality, + and any other pertinent information, should be implemented in subclasses. + + :return: A dictionary containing key-value pairs representing the current state of the software. + :rtype: Dict + """ + state = super().describe_state() + return state + + def reset_component_for_episode(self, episode: int): + """ + Resets the Service component for a new episode. + + This method ensures the Service is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + pass + + def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address): + """ + Adds a domain name to the DNS Client cache. + + :param: domain_name: The domain name to save to cache + :param: ip_address: The IP Address to attach the domain name to + """ + self.dns_cache[domain_name] = ip_address + + def check_domain_exists( + self, + target_domain: str, + session_id: Optional[str] = None, + is_reattempt: bool = False, + ) -> bool: + """Function to check if domain name exists. + + :param: target_domain: The domain requested for an IP address. + :param: session_id: The Session ID the payload is to originate from. Optional. + :param: is_reattempt: Checks if the request has been reattempted. Default is False. + """ + # check if the target domain is in the client's DNS cache + payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain)) + + # check if the domain is already in the DNS cache + if target_domain in self.dns_cache: + self.sys_log.info( + f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}" + ) + return True + else: + # return False if already reattempted + if is_reattempt: + self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed") + return False + else: + # send a request to check if domain name exists in the DNS Server + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS + ) + + # recursively re-call the function passing is_reattempt=True + return self.check_domain_exists( + target_domain=target_domain, + session_id=session_id, + is_reattempt=True, + ) + + def send( + self, + payload: DNSPacket, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Sends a payload to the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to be sent. + :param dest_ip_address: The ip address of the payload destination. + :param dest_port: The port of the payload destination. + :param session_id: The Session ID the payload is to originate from. Optional. + + :return: True if successful, False otherwise. + """ + # create DNS request packet + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) + return True + + def receive( + self, + payload: DNSPacket, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param payload: The payload to be sent. + :param session_id: The Session ID the payload is to originate from. Optional. + :return: True if successful, False otherwise. + """ + # The payload should be a DNS packet + if not isinstance(payload, DNSPacket): + _LOGGER.debug(f"{payload} is not a DNSPacket") + return False + # cast payload into a DNS packet + payload: DNSPacket = payload + if payload.dns_reply is not None: + # add the IP address to the client cache + if payload.dns_reply.domain_name_ip_address: + self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address + return True + + return False diff --git a/src/primaite/simulator/system/services/dns_server.py b/src/primaite/simulator/system/services/dns_server.py new file mode 100644 index 00000000..c6a9afd3 --- /dev/null +++ b/src/primaite/simulator/system/services/dns_server.py @@ -0,0 +1,122 @@ +from ipaddress import IPv4Address +from typing import Any, Dict, Optional + +from prettytable import MARKDOWN, PrettyTable + +from primaite import getLogger +from primaite.simulator.network.protocols.dns import DNSPacket +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service + +_LOGGER = getLogger(__name__) + + +class DNSServer(Service): + """Represents a DNS Server as a Service.""" + + dns_table: Dict[str, IPv4Address] = {} + "A dict of mappings between domain names and IPv4 addresses." + + def __init__(self, **kwargs): + kwargs["name"] = "DNSServer" + kwargs["port"] = Port.DNS + # DNS uses UDP by default + # it switches to TCP when the bytes exceed 512 (or 4096) bytes + # TCP for now + kwargs["protocol"] = IPProtocol.TCP + super().__init__(**kwargs) + + def describe_state(self) -> Dict: + """ + Describes the current state of the software. + + The specifics of the software's state, including its health, criticality, + and any other pertinent information, should be implemented in subclasses. + + :return: A dictionary containing key-value pairs representing the current state of the software. + :rtype: Dict + """ + state = super().describe_state() + return state + + def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]: + """ + Attempts to find the IP address for a domain name. + + :param target_domain: The single domain name requested by a DNS client. + :return ip_address: The IP address of that domain name or None. + """ + return self.dns_table.get(target_domain) + + def dns_register(self, domain_name: str, domain_ip_address: IPv4Address): + """ + Register a domain name and its IP address. + + :param: domain_name: The domain name to register + :type: domain_name: str + + :param: domain_ip_address: The IP address that the domain should route to + :type: domain_ip_address: IPv4Address + """ + self.dns_table[domain_name] = domain_ip_address + + def reset_component_for_episode(self, episode: int): + """ + Resets the Service component for a new episode. + + This method ensures the Service is ready for a new episode, including resetting any + stateful properties or statistics, and clearing any message queues. + """ + pass + + def receive( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: + """ + Receives a payload from the SessionManager. + + The specifics of how the payload is processed and whether a response payload + is generated should be implemented in subclasses. + + :param: payload: The payload to send. + :param: session_id: The id of the session. Optional. + + :return: True if DNS request returns a valid IP, otherwise, False + """ + # The payload should be a DNS packet + if not isinstance(payload, DNSPacket): + _LOGGER.debug(f"{payload} is not a DNSPacket") + return False + # cast payload into a DNS packet + payload: DNSPacket = payload + if payload.dns_request is not None: + self.sys_log.info( + f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} " + f"from session {session_id}" + ) + # generate a reply with the correct DNS IP address + payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request)) + self.sys_log.info( + f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} " + f"with ip address: {payload.dns_reply.domain_name_ip_address}" + ) + # send reply + self.send(payload, session_id) + return payload.dns_reply.domain_name_ip_address is not None + + return False + + def show(self, markdown: bool = False): + """Prints a table of DNS Lookup table.""" + table = PrettyTable(["Domain Name", "IP Address"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} DNS Lookup table" + for dns in self.dns_table.items(): + table.add_row([dns[0], dns[1]]) + print(table) diff --git a/src/primaite/simulator/system/services/red_services/__init__.py b/src/primaite/simulator/system/services/red_services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py new file mode 100644 index 00000000..30643b32 --- /dev/null +++ b/src/primaite/simulator/system/services/red_services/data_manipulation_bot.py @@ -0,0 +1,49 @@ +from ipaddress import IPv4Address +from typing import Optional + +from primaite.simulator.system.applications.database_client import DatabaseClient + + +class DataManipulationBot(DatabaseClient): + """ + Red Agent Data Integration Service. + + The Service represents a bot that causes files/folders in the File System to + become corrupted. + """ + + server_ip_address: Optional[IPv4Address] = None + payload: Optional[str] = None + server_password: Optional[str] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.name = "DataManipulationBot" + + def configure( + self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None + ): + """ + Configure the DataManipulatorBot to communicate with a DatabaseService. + + :param server_ip_address: The IP address of the Node the DatabaseService is on. + :param server_password: The password on the DatabaseService. + :param payload: The data manipulation query payload. + """ + self.server_ip_address = server_ip_address + self.payload = payload + self.server_password = server_password + self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}.") + + def run(self): + """Run the DataManipulationBot.""" + if self.server_ip_address and self.payload: + self.sys_log.info(f"Attempting to start the {self.name}") + super().run() + if not self.connected: + self.connect() + if self.connected: + self.query(self.payload) + self.sys_log.info(f"{self.name} payload delivered: {self.payload}") + else: + self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_io_address and payload.") diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index f9cc784d..20b92027 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -1,4 +1,3 @@ -from abc import abstractmethod from enum import Enum from typing import Any, Dict, Optional @@ -33,7 +32,7 @@ class Service(IOSoftware): Services are programs that run in the background and may perform input/output operations. """ - operating_state: ServiceOperatingState + operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED "The current operating state of the Service." restart_duration: int = 5 "How many timesteps does it take to restart this service." @@ -51,7 +50,6 @@ class Service(IOSoftware): am.add_action("enable", Action(func=lambda request, context: self.enable())) return am - @abstractmethod def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -74,77 +72,85 @@ class Service(IOSoftware): """ pass - def send(self, payload: Any, session_id: str, **kwargs) -> bool: + def send( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Sends a payload to the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to send. + :param: payload: The payload to send. + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id) - def receive(self, payload: Any, session_id: str, **kwargs) -> bool: + def receive( + self, + payload: Any, + session_id: Optional[str] = None, + **kwargs, + ) -> bool: """ Receives a payload from the SessionManager. The specifics of how the payload is processed and whether a response payload is generated should be implemented in subclasses. - :param payload: The payload to receive. + :param: payload: The payload to send. + :param: session_id: The id of the session + :return: True if successful, False otherwise. """ - pass + + pass def stop(self) -> None: """Stop the service.""" - _LOGGER.debug(f"Stopping service {self.name}") if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: - self.parent.sys_log.info(f"Stopping service {self.name}") + self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED - def start(self) -> None: + def start(self, **kwargs) -> None: """Start the service.""" - _LOGGER.debug(f"Starting service {self.name}") if self.operating_state == ServiceOperatingState.STOPPED: - self.parent.sys_log.info(f"Starting service {self.name}") + self.sys_log.info(f"Starting service {self.name}") self.operating_state = ServiceOperatingState.RUNNING def pause(self) -> None: """Pause the service.""" - _LOGGER.debug(f"Pausing service {self.name}") if self.operating_state == ServiceOperatingState.RUNNING: - self.parent.sys_log.info(f"Pausing service {self.name}") + self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED def resume(self) -> None: """Resume paused service.""" - _LOGGER.debug(f"Resuming service {self.name}") if self.operating_state == ServiceOperatingState.PAUSED: - self.parent.sys_log.info(f"Resuming service {self.name}") + self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING def restart(self) -> None: """Restart running service.""" - _LOGGER.debug(f"Restarting service {self.name}") if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: - self.parent.sys_log.info(f"Pausing service {self.name}") + self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING self.restart_countdown = self.restarting_duration def disable(self) -> None: """Disable the service.""" - _LOGGER.debug(f"Disabling service {self.name}") - self.parent.sys_log.info(f"Disabling Application {self.name}") + self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED def enable(self) -> None: """Enable the disabled service.""" - _LOGGER.debug(f"Enabling service {self.name}") if self.operating_state == ServiceOperatingState.DISABLED: - self.parent.sys_log.info(f"Enabling Application {self.name}") + self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED def apply_timestep(self, timestep: int) -> None: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 605a062b..70c1bbf2 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -1,9 +1,11 @@ from abc import abstractmethod from enum import Enum -from typing import Any, Dict, Set +from typing import Any, Dict, Optional from primaite.simulator.core import Action, ActionManager, SimComponent +from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.sys_log import SysLog class SoftwareType(Enum): @@ -62,11 +64,11 @@ class Software(SimComponent): name: str "The name of the software." - health_state_actual: SoftwareHealthState + health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD "The actual health state of the software." - health_state_visible: SoftwareHealthState + health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD "The health state of the software visible to the red agent." - criticality: SoftwareCriticality + criticality: SoftwareCriticality = SoftwareCriticality.LOWEST "The criticality level of the software." patching_count: int = 0 "The count of patches applied to the software, defaults to 0." @@ -74,6 +76,14 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." + software_manager: Any = None + "An instance of Software Manager that is used by the parent node." + sys_log: SysLog = None + "An instance of SysLog that is used by the parent node." + file_system: FileSystem + "The FileSystem of the Node the Software is installed on." + folder: Optional[Folder] = None + "The folder on the file system the Software uses." def _init_action_manager(self) -> ActionManager: am = super()._init_action_manager() @@ -132,7 +142,6 @@ class Software(SimComponent): """ self.health_state_actual = health_state - @abstractmethod def install(self) -> None: """ Perform first-time setup of this service on a node. @@ -175,8 +184,8 @@ class IOSoftware(Software): "Indicates if the software uses TCP protocol for communication. Default is True." udp: bool = True "Indicates if the software uses UDP protocol for communication. Default is True." - ports: Set[Port] - "The set of ports to which the software is connected." + port: Port + "The port to which the software is connected." @abstractmethod def describe_state(self) -> Dict: @@ -212,7 +221,6 @@ class IOSoftware(Software): :param kwargs: Additional keyword arguments specific to the implementation. :return: True if the payload was successfully sent, False otherwise. """ - pass def receive(self, payload: Any, session_id: str, **kwargs) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index f1c05187..35548f2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ import pytest from primaite import getLogger from primaite.environment.primaite_env import Primaite from primaite.primaite_session import PrimaiteSession +from primaite.simulator.network.container import Network +from primaite.simulator.network.networks import arcd_uc2_network from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 @@ -19,7 +21,22 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) +# PrimAITE v3 stuff +from primaite.simulator.file_system.file_system import FileSystem +from primaite.simulator.network.hardware.base import Node + +@pytest.fixture(scope="function") +def uc2_network() -> Network: + return arcd_uc2_network() + + +@pytest.fixture(scope="function") +def file_system() -> FileSystem: + return Node(hostname="fs_node").file_system + + +# PrimAITE v2 stuff class TempPrimaiteSession(PrimaiteSession): """ A temporary PrimaiteSession class. diff --git a/tests/e2e_integration_tests/__init__.py b/tests/e2e_integration_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py new file mode 100644 index 00000000..a859e5ff --- /dev/null +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -0,0 +1,25 @@ +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database_service import DatabaseService +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot + + +def test_data_manipulation(uc2_network): + client_1: Computer = uc2_network.get_node_by_hostname("client_1") + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + + database_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = database_server.software_manager.software["DatabaseService"] + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + + # First check that the DB client on the web_server can successfully query the users table on the database + assert db_client.query("SELECT * FROM user;") + + # Now we run the DataManipulationBot + db_manipulation_bot.run() + + # Now check that the DB client on the web_server cannot query the users table on the database + assert not db_client.query("SELECT * FROM user;") diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index 73d19339..2a77a31b 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -1,56 +1,59 @@ -from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.database import DatabaseService -from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database_service import DatabaseService -def test_installing_database(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - ports=[ - Port.SQL_SERVER, - ], - operating_state=ServiceOperatingState.RUNNING, - ) +def test_database_client_server_connection(uc2_network): + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] - node = Node(hostname="db-server") + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] - node.install_service(db) + assert len(db_service.connections) == 1 - assert db in node - - file_exists = False - for folder in node.file_system.folders.values(): - for file in folder.files.values(): - if file.name == "db_primary_store": - file_exists = True - break - if file_exists: - break - assert file_exists + db_client.disconnect() + assert len(db_service.connections) == 0 -def test_uninstalling_database(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - ports=[ - Port.SQL_SERVER, - ], - operating_state=ServiceOperatingState.RUNNING, - ) +def test_database_client_server_correct_password(uc2_network): + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] - node = Node(hostname="db-server") + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] - node.install_service(db) + db_client.disconnect() - node.uninstall_service(db) + db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345") + db_service.password = "12345" - assert db not in node - assert node.file_system.get_folder_by_name("database") is None + assert db_client.connect() + + assert len(db_service.connections) == 1 + + +def test_database_client_server_incorrect_password(uc2_network): + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + + db_server: Server = uc2_network.get_node_by_hostname("database_server") + db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + + db_client.disconnect() + db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321") + db_service.password = "12345" + + assert not db_client.connect() + assert len(db_service.connections) == 0 + + +def test_database_client_query(uc2_network): + """Tests DB query across the network returns HTTP status 200 and date.""" + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() + + assert db_client.query("SELECT * FROM user;") diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py new file mode 100644 index 00000000..640c268a --- /dev/null +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -0,0 +1,28 @@ +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer +from primaite.simulator.system.services.service import ServiceOperatingState + + +def test_dns_client_server(uc2_network): + client_1: Computer = uc2_network.get_node_by_hostname("client_1") + domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller") + + dns_client: DNSClient = client_1.software_manager.software["DNSClient"] + dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] + + assert dns_client.operating_state == ServiceOperatingState.RUNNING + assert dns_server.operating_state == ServiceOperatingState.RUNNING + + dns_server.show() + + # fake domain should not be added to dns cache + assert not dns_client.check_domain_exists(target_domain="fake-domain.com") + assert dns_client.dns_cache.get("fake-domain.com", None) is None + + # arcd.com is registered in dns server and should be saved to cache + assert dns_client.check_domain_exists(target_domain="arcd.com") + assert dns_client.dns_cache.get("arcd.com", None) is not None diff --git a/tests/test_seeding_and_deterministic_session.py b/tests/test_seeding_and_deterministic_session.py index 9500c4a3..aff5496a 100644 --- a/tests/test_seeding_and_deterministic_session.py +++ b/tests/test_seeding_and_deterministic_session.py @@ -45,7 +45,6 @@ def test_seeded_learning(temp_primaite_session): ), "Expected output is based upon a agent that was trained with seed 67890" session.learn() actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict() - print(actual_mean_reward_per_episode, "THISt") assert actual_mean_reward_per_episode == expected_mean_reward_per_episode diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 348eb440..d1d78003 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,132 +1,144 @@ +import pytest + from primaite.simulator.file_system.file_system import FileSystem -from primaite.simulator.file_system.file_system_file import FileSystemFile -from primaite.simulator.file_system.file_system_folder import FileSystemFolder +from primaite.simulator.file_system.file_type import FileType -def test_create_folder_and_file(): +def test_create_folder_and_file(file_system): """Test creating a folder and a file.""" - file_system = FileSystem() - folder = file_system.create_folder(folder_name="test_folder") - assert len(file_system.folders) is 1 + assert len(file_system.folders) == 1 + file_system.create_folder(folder_name="test_folder") - file = file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid) - assert len(file_system.get_folder_by_id(folder.uuid).files) is 1 + assert len(file_system.folders) is 2 + file_system.create_file(file_name="test_file.txt", folder_name="test_folder") - assert file_system.get_file_by_id(file.uuid).name is "test_file" - assert file_system.get_file_by_id(file.uuid).size == 10 + assert len(file_system.get_folder("test_folder").files) == 1 + + assert file_system.get_folder("test_folder").get_file("test_file.txt") -def test_create_file(): +def test_create_file_no_folder(file_system): """Tests that creating a file without a folder creates a folder and sets that as the file's parent.""" - file_system = FileSystem() - - file = file_system.create_file(file_name="test_file", size=10) + file = file_system.create_file(file_name="test_file.txt", size=10) assert len(file_system.folders) is 1 - assert file_system.get_folder_by_name("root").get_file_by_id(file.uuid) is file + assert file_system.get_folder("root").get_file("test_file.txt") == file + assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT + assert file_system.get_folder("root").get_file("test_file.txt").size == 10 -def test_delete_file(): +def test_create_file_no_extension(file_system): + """Tests that creating a file without an extension sets the file type to FileType.UNKNOWN.""" + file = file_system.create_file(file_name="test_file") + assert len(file_system.folders) is 1 + assert file_system.get_folder("root").get_file("test_file") == file + assert file_system.get_folder("root").get_file("test_file").file_type == FileType.UNKNOWN + assert file_system.get_folder("root").get_file("test_file").size == 0 + + +def test_delete_file(file_system): """Tests that a file can be deleted.""" - file_system = FileSystem() + file_system.create_file(file_name="test_file.txt") + assert len(file_system.folders) == 1 + assert len(file_system.get_folder("root").files) == 1 - file = file_system.create_file(file_name="test_file", size=10) - assert len(file_system.folders) is 1 - - folder_id = list(file_system.folders.keys())[0] - folder = file_system.get_folder_by_id(folder_id) - assert folder.get_file_by_id(file.uuid) is file - - file_system.delete_file(file=file) - assert len(file_system.folders) is 1 - assert len(folder.files) is 0 + file_system.delete_file(folder_name="root", file_name="test_file.txt") + assert len(file_system.folders) == 1 + assert len(file_system.get_folder("root").files) == 0 -def test_delete_non_existent_file(): +def test_delete_non_existent_file(file_system): """Tests deleting a non existent file.""" - file_system = FileSystem() - - file = file_system.create_file(file_name="test_file", size=10) - not_added_file = FileSystemFile(name="not_added") + file_system.create_file(file_name="test_file.txt") # folder should be created - assert len(file_system.folders) is 1 + assert len(file_system.folders) == 1 # should only have 1 file in the file system - folder_id = list(file_system.folders.keys())[0] - folder = file_system.get_folder_by_id(folder_id) - assert len(list(folder.files)) is 1 - - assert folder.get_file_by_id(file.uuid) is file + assert len(file_system.get_folder("root").files) == 1 # deleting should not change how many files are in folder - file_system.delete_file(file=not_added_file) - assert len(file_system.folders) is 1 - assert len(list(folder.files)) is 1 + file_system.delete_file(folder_name="root", file_name="does_not_exist!") + + # should still only be one folder + assert len(file_system.folders) == 1 + # The folder should still have 1 file + assert len(file_system.get_folder("root").files) == 1 -def test_delete_folder(): - file_system = FileSystem() - folder = file_system.create_folder(folder_name="test_folder") - assert len(file_system.folders) is 1 +def test_delete_folder(file_system): + file_system.create_folder(folder_name="test_folder") + assert len(file_system.folders) == 2 - file_system.delete_folder(folder) - assert len(file_system.folders) is 0 + file_system.delete_folder(folder_name="test_folder") + assert len(file_system.folders) == 1 -def test_deleting_a_non_existent_folder(): - file_system = FileSystem() - folder = file_system.create_folder(folder_name="test_folder") - not_added_folder = FileSystemFolder(name="fake_folder") - assert len(file_system.folders) is 1 +def test_deleting_a_non_existent_folder(file_system): + file_system.create_folder(folder_name="test_folder") + assert len(file_system.folders) == 2 - file_system.delete_folder(not_added_folder) - assert len(file_system.folders) is 1 + file_system.delete_folder(folder_name="does not exist!") + assert len(file_system.folders) == 2 -def test_move_file(): +def test_deleting_root_folder_fails(file_system): + assert len(file_system.folders) == 1 + + file_system.delete_folder(folder_name="root") + assert len(file_system.folders) == 1 + + +def test_move_file(file_system): """Tests the file move function.""" - file_system = FileSystem() - src_folder = file_system.create_folder(folder_name="test_folder_1") - assert len(file_system.folders) is 1 + file_system.create_folder(folder_name="src_folder") + file_system.create_folder(folder_name="dst_folder") - target_folder = file_system.create_folder(folder_name="test_folder_2") - assert len(file_system.folders) is 2 + file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder") + original_uuid = file.uuid - file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid) - assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1 - assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0 + assert len(file_system.get_folder("src_folder").files) == 1 + assert len(file_system.get_folder("dst_folder").files) == 0 - file_system.move_file(file=file, src_folder=src_folder, target_folder=target_folder) + file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") - assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 0 - assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1 + assert len(file_system.get_folder("src_folder").files) == 0 + assert len(file_system.get_folder("dst_folder").files) == 1 + assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid -def test_copy_file(): +def test_copy_file(file_system): """Tests the file copy function.""" - file_system = FileSystem() - src_folder = file_system.create_folder(folder_name="test_folder_1") - assert len(file_system.folders) is 1 + file_system.create_folder(folder_name="src_folder") + file_system.create_folder(folder_name="dst_folder") - target_folder = file_system.create_folder(folder_name="test_folder_2") - assert len(file_system.folders) is 2 + file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) + original_uuid = file.uuid - file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid) - assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1 - assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0 + assert len(file_system.get_folder("src_folder").files) == 1 + assert len(file_system.get_folder("dst_folder").files) == 0 - file_system.copy_file(file=file, src_folder=src_folder, target_folder=target_folder) + file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") - assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1 - assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1 + assert len(file_system.get_folder("src_folder").files) == 1 + assert len(file_system.get_folder("dst_folder").files) == 1 + assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid -def test_serialisation(): +def test_folder_quarantine_state(file_system): + """Tests the changing of folder quarantine status.""" + folder = file_system.get_folder("root") + + assert folder.quarantine_status() is False + + folder.quarantine() + assert folder.quarantine_status() is True + + folder.unquarantine() + assert folder.quarantine_status() is False + + +@pytest.mark.skip(reason="Skipping until we tackle serialisation") +def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" - file_system = FileSystem() - folder = file_system.create_folder(folder_name="test_folder") - assert len(file_system.folders) is 1 - - file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid) - assert file_system.get_folder_by_id(folder.uuid) is folder + file_system.create_file(file_name="test_file.txt") serialised_file_sys = file_system.model_dump_json() deserialised_file_sys = FileSystem.model_validate_json(serialised_file_sys) diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py deleted file mode 100644 index 629b9bb9..00000000 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_file.py +++ /dev/null @@ -1,23 +0,0 @@ -from primaite.simulator.file_system.file_system_file import FileSystemFile -from primaite.simulator.file_system.file_system_file_type import FileSystemFileType - - -def test_file_type(): - """Tests tha the FileSystemFile type is set correctly.""" - file = FileSystemFile(name="test", file_type=FileSystemFileType.DOC) - assert file.file_type is FileSystemFileType.DOC - - -def test_get_size(): - """Tests that the file size is being returned properly.""" - file = FileSystemFile(name="test", size=1.5) - assert file.size == 1.5 - - -def test_serialisation(): - """Test to check that the object serialisation works correctly.""" - file = FileSystemFile(name="test", size=1.5, file_type=FileSystemFileType.DOC) - serialised_file = file.model_dump_json() - deserialised_file = FileSystemFile.model_validate_json(serialised_file) - - assert file.model_dump_json() == deserialised_file.model_dump_json() diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py deleted file mode 100644 index 1940e886..00000000 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system_folder.py +++ /dev/null @@ -1,75 +0,0 @@ -from primaite.simulator.file_system.file_system_file import FileSystemFile -from primaite.simulator.file_system.file_system_file_type import FileSystemFileType -from primaite.simulator.file_system.file_system_folder import FileSystemFolder - - -def test_adding_removing_file(): - """Test the adding and removing of a file from a folder.""" - folder = FileSystemFolder(name="test") - - file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC) - - folder.add_file(file) - assert folder.size == 10 - assert len(folder.files) is 1 - - folder.remove_file(file) - assert folder.size == 0 - assert len(folder.files) is 0 - - -def test_remove_non_existent_file(): - """Test the removing of a file that does not exist.""" - folder = FileSystemFolder(name="test") - - file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC) - not_added_file = FileSystemFile(name="fake_file", size=10, file_type=FileSystemFileType.DOC) - - folder.add_file(file) - assert folder.size == 10 - assert len(folder.files) is 1 - - folder.remove_file(not_added_file) - assert folder.size == 10 - assert len(folder.files) is 1 - - -def test_get_file_by_id(): - """Test to make sure that the correct file is returned.""" - folder = FileSystemFolder(name="test") - - file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC) - file2 = FileSystemFile(name="test_file_2", size=10, file_type=FileSystemFileType.DOC) - - folder.add_file(file) - folder.add_file(file2) - assert folder.size == 20 - assert len(folder.files) is 2 - - assert folder.get_file_by_id(file_id=file.uuid) is file - - -def test_folder_quarantine_state(): - """Tests the changing of folder quarantine status.""" - folder = FileSystemFolder(name="test") - - assert folder.quarantine_status() is False - - folder.quarantine() - assert folder.quarantine_status() is True - - folder.end_quarantine() - assert folder.quarantine_status() is False - - -def test_serialisation(): - """Test to check that the object serialisation works correctly.""" - folder = FileSystemFolder(name="test") - file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC) - folder.add_file(file) - - serialised_folder = folder.model_dump_json() - - deserialised_folder = FileSystemFolder.model_validate_json(serialised_folder) - - assert folder.model_dump_json() == deserialised_folder.model_dump_json() diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 290e7cc3..66bd59a9 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -1,5 +1,7 @@ import json +import pytest + from primaite.simulator.network.container import Network @@ -10,6 +12,7 @@ def test_creating_container(): assert net.links == {} +@pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_describe_state(): """Check that we can describe network state without raising errors, and that the result is JSON serialisable.""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/__init__.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py new file mode 100644 index 00000000..dd785cc1 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -0,0 +1,20 @@ +from ipaddress import IPv4Address + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.networks import arcd_uc2_network +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot + + +def test_creation(): + network = arcd_uc2_network() + + client_1: Node = network.get_node_by_hostname("client_1") + + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + + assert data_manipulation_bot.name == "DataManipulationBot" + assert data_manipulation_bot.port == Port.POSTGRES_SERVER + assert data_manipulation_bot.protocol == IPProtocol.TCP + assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index ea5c1b83..d41c63c7 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -1,17 +1,18 @@ -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.database import DatabaseService -from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState +import json + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.system.services.database_service import DatabaseService -def test_creation(): - db = DatabaseService( - name="SQL-database", - health_state_actual=SoftwareHealthState.GOOD, - health_state_visible=SoftwareHealthState.GOOD, - criticality=SoftwareCriticality.MEDIUM, - ports=[ - Port.SQL_SERVER, - ], - operating_state=ServiceOperatingState.RUNNING, - ) +@pytest.fixture(scope="function") +def database_server() -> Node: + node = Node(hostname="db_node") + node.software_manager.install(DatabaseService) + node.software_manager.software["DatabaseService"].start() + return node + + +def test_creation(database_server): + database_server.software_manager.show() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py new file mode 100644 index 00000000..b4f20539 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py @@ -0,0 +1,100 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.dns_client import DNSClient +from primaite.simulator.system.services.dns_server import DNSServer + + +@pytest.fixture(scope="function") +def dns_server() -> Node: + node = Node(hostname="dns_server") + node.software_manager.install(software_class=DNSServer) + node.software_manager.software["DNSServer"].start() + return node + + +@pytest.fixture(scope="function") +def dns_client() -> Node: + node = Node(hostname="dns_client") + node.software_manager.install(software_class=DNSClient) + node.software_manager.software["DNSClient"].start() + return node + + +def test_create_dns_server(dns_server): + assert dns_server is not None + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + assert dns_server_service.name is "DNSServer" + assert dns_server_service.port is Port.DNS + assert dns_server_service.protocol is IPProtocol.TCP + + +def test_create_dns_client(dns_client): + assert dns_client is not None + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + assert dns_client_service.name is "DNSClient" + assert dns_client_service.port is Port.DNS + assert dns_client_service.protocol is IPProtocol.TCP + + +def test_dns_server_domain_name_registration(dns_server): + """Test to check if the domain name registration works.""" + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + # return none for an unknown domain + assert dns_server_service.dns_lookup("fake-domain.com") is None + assert dns_server_service.dns_lookup("real-domain.com") is not None + + +def test_dns_client_check_domain_in_cache(dns_client): + """Test to make sure that the check_domain_in_cache returns the correct values.""" + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + + # add a domain to the dns client cache + dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12")) + + assert dns_client_service.check_domain_exists("fake-domain.com") is False + assert dns_client_service.check_domain_exists("real-domain.com") is True + + +def test_dns_server_receive(dns_server): + """Test to make sure that the DNS Server correctly responds to a DNS Client request.""" + dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com"))) + is False + ) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com"))) + is True + ) + + dns_server_service.show() + + +def test_dns_client_receive(dns_client): + """Test to make sure the DNS Client knows how to deal with request responses.""" + dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + + dns_client_service.receive( + payload=DNSPacket( + dns_request=DNSRequest(domain_name_request="real-domain.com"), + dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")), + ) + ) + + # domain name should be saved to cache + assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")