Merged PR 178: Database Client/Server Simulation

## Summary
This pull request focuses on implementing key functionalities needed for network frame processing and database interactions. The primary changes are:

1. **Internal Frame Processing:** The logic has been implemented in various components like `NIC`, `Node`, `SessionManager`, and `SoftwareManager`. These changes enable the system to process incoming and outgoing network frames in a structured manner.
2. **Database Service and Client:** The `DatabaseService` simulates a SQL database server, while the `DatabaseClient` provides a client interface for connecting to this service. These functionalities have been built and integrated into the existing architecture.
3. **Networking and Communication:** Tests have been added to confirm that database queries can be sent over the network, demonstrating end-to-end functionality.

## Commits

- #1816 Simplified a bunch of stuff in the file system in prep for services and applications. Started adding the database logic. Waiting for the software manager/session manager work from another tick. Merge branch 'dev' into feature/1816_Database-Service-(Network-and-User-Interaction)
- #1816 Added the final pieces of the puzzle to get data up from NIC → session manager → software manager → service.
- #1816 DatabaseService now uses the send function when responding.
- #1816 Added database client. Installed the database client on the Web Server node in the UC2 network. Updated the integration test to query the DB server using the DB client.
- #1816 Added full documentation on the database client/server, and the internal frame processing process
- #1816 Fixed tests. Used node and link added number (id) in observation space.

## Test process
For testing these functionalities, the following steps were taken:

1. **Unit Tests:** Tests have been written to confirm that database queries can be sent over the network successfully.
2. **Integration Tests:** Manually tested the frame processing flow from NIC to Service/Application, ensuring the functionality behaves as expected.
3. **Database Queries** Executed sample SQL queries using the `DatabaseClient` to make sure it interacts correctly with the `DatabaseService`.

## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have written/updated **design docs** if this PR implements new functionality
- [ ] I have update the **change log**
- [ ] I have run **pre-commit** checks for code style

Related work items: #1816
This commit is contained in:
Christopher McCarthy
2023-09-12 13:31:08 +00:00
committed by Czar Echavez
43 changed files with 1775 additions and 1030 deletions

View File

@@ -98,6 +98,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/getting_started
source/about
source/config
source/simulation
source/primaite_session
source/custom_agent
PrimAITE API <source/_autosummary/primaite>

View File

@@ -21,3 +21,5 @@ Contents
simulation_components/network/router
simulation_components/network/switch
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/software

View File

@@ -2,7 +2,7 @@
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:
.. _network:
Network
=======

View File

@@ -2,7 +2,7 @@
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:
.. _router:
Router Module
=============

View File

@@ -0,0 +1,58 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
DataManipulationBot
===================
The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements.
Overview
--------
The bot is intended to simulate a malicious actor carrying out attacks like:
- Dropping tables
- Deleting records
- Modifying data
On a database server by abusing an application's trusted database connectivity.
Usage
-----
- Create an instance and call ``configure`` to set:
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Call ``run`` to connect and execute the statement.
The bot handles connecting, executing the statement, and disconnecting.
Example
-------
.. code-block:: python
client_1 = Computer(
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
data_manipulation_bot.run()
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.
Implementation
--------------
The bot extends ``DatabaseClient`` and leverages its connectivity.
- Uses the Application base class for lifecycle management.
- Credentials and target IP set via ``configure``.
- ``run`` handles connecting, executing statement, and disconnecting.
- SQL payload executed via ``query`` method.
- Results in malicious SQL being executed on remote database server.

View File

@@ -0,0 +1,70 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Database Client Server
======================
Database Service
----------------
The ``DatabaseService`` provides a SQL database server simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Initialises a SQLite database file in the ``Node``'s ``FileSystem`` upon creation.
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
- Authenticates connections using a configurable password.
- Executes SQL queries against the SQLite database.
- Returns query results and status codes back to clients.
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Clients connect, execute queries, and disconnect.
- Service runs on TCP port 5432 by default.
Implementation
^^^^^^^^^^^^^^
- Uses SQLite for persistent storage.
- Creates the database file within the node's file system.
- Manages client connections in a dictionary by session ID.
- Processes SQL queries via the SQLite cursor and connection.
- Returns results and status codes in a standard dictionary format.
- Extends Service class for integration with ``SoftwareManager``.
Database Client
---------------
The DatabaseClient provides a client interface for connecting to the ``DatabaseService``.
Key features
^^^^^^^^^^^^
- Connects to the ``DatabaseService`` via the ``SoftwareManager``.
- Executes SQL queries and retrieves result sets.
- Handles connecting, querying, and disconnecting.
- Provides a simple ``query`` method for running SQL.
Usage
^^^^^
- Initialise with server IP address and optional password.
- Connect to the ``DatabaseService`` with ``connect``.
- Execute SQL queries via ``query``.
- Retrieve results in a dictionary.
- Disconnect when finished.
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Connect and disconnect methods manage sessions.
- Provides easy interface for applications to query database.
- Payloads serialised as dictionaries for transmission.
- Extends base Application class.

View File

@@ -0,0 +1,98 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _internal_frame_processing:
Internal Frame Processing
=========================
Inbound
-------
At the NIC
^^^^^^^^^^
When a Frame is received on the Node's NIC:
- The NIC checks if it is enabled. If so, it will process the Frame.
- The Frame's received timestamp is set.
- The Frame is captured by the NIC's PacketCapture if configured.
- The NIC decrements the IP Packet's TTL by 1.
- The NIC calls the Node's ``receive_frame`` method, passing itself as the receiving NIC and the Frame.
At the Node
^^^^^^^^^^^
When ``receive_frame`` is called on the Node:
- The source IP address is added to the ARP cache if not already present.
- The Frame's protocol is checked:
- If ARP or ICMP, the Frame is passed to that protocol's handler method.
- Otherwise it is passed to the SessionManager's ``receive_frame`` method.
At the SessionManager
^^^^^^^^^^^^^^^^^^^^^
When ``receive_frame`` is called on the SessionManager:
- It extracts the key session details from the Frame:
- Protocol (TCP, UDP, etc)
- Source IP
- Destination IP
- Source Port
- Destination Port
- It checks if an existing Session matches these details.
- If no match, a new Session is created to represent this exchange.
- The payload and new/existing Session ID are passed to the SoftwareManager's ``receive_payload_from_session_manager`` method.
At the SoftwareManager
^^^^^^^^^^^^^^^^^^^^^^
Inside ``receive_payload_from_session_manager``:
- The SoftwareManager checks its port/protocol mapping to find which Service or Application is listening on the destination port and protocol.
- The payload and Session ID are forwarded to that receiver Service/Application instance via their ``receive`` method.
- The Service/Application can then process the payload as needed.
Outbound
--------
At the Service/Application
^^^^^^^^^^^^^^^^^^^^^^^^^^
When a Service or Application needs to send a payload:
- It calls the SoftwareManager's ``send_payload_to_session_manager`` method.
- Passes the payload, and either destination IP and destination port for new payloads, or session id for existing sessions.
At the SoftwareManager
^^^^^^^^^^^^^^^^^^^^^^
Inside ``send_payload_to_session_manager``:
- The SoftwareManager forwards the payload and details through to to the SessionManager's ``receive_payload_from_software_manager`` method.
At the SessionManager
^^^^^^^^^^^^^^^^^^^^^
When ``receive_payload_from_software_manager`` is called:
- If a Session ID was provided, it looks up the Session.
- Gets the destination MAC address by checking the ARP cache.
- If no Session ID was provided, the destination Port, IP address and Mac Address are used along with the outbound IP Address and Mac Address to create a new Session.
- Calls `send_payload_to_nic`` to construct and send the Frame.
When ``send_payload_to_nic`` is called:
- It constructs a new Frame with the payload, using the source NIC's MAC, source IP, destination MAC, etc.
- The outbound NIC is looked up via the ARP cache based on destination IP.
- The constructed Frame is passed to the outbound NIC's ``send_frame`` method.
At the NIC
^^^^^^^^^^
When ``send_frame`` is called:
- The NIC checks if it is enabled before sending.
- If enabled, it sends the Frame out to the connected Link.

View File

@@ -0,0 +1,19 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Software
========
Contents
########
.. toctree::
:maxdepth: 8
database_client_server
data_manipulation_bot

View File

@@ -1,5 +1,14 @@
from datetime import datetime
from primaite import _PRIMAITE_ROOT
TEMP_SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output"
SIM_OUTPUT = None
"A path at the repo root dir to use temporarily for sim output testing while in dev."
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
if not SIM_OUTPUT:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)

View File

@@ -149,8 +149,8 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.file_system.file_system_file_type import FileSystemFileType\n",
"from primaite.simulator.file_system.file_system_file import FileSystemFile"
"from primaite.simulator.file_system.file_type import FileType\n",
"from primaite.simulator.file_system.file_system import File"
]
},
{
@@ -160,7 +160,7 @@
"outputs": [],
"source": [
"my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n",
"my_pc_downloads_folder.add_file(FileSystemFile(name=\"firefox_installer.zip\",file_type=FileSystemFileType.ZIP))"
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))"
]
},
{
@@ -171,7 +171,7 @@
{
"data": {
"text/plain": [
"FileSystemFile(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=<FileSystemFileType.PNG: '11'>, action_manager=None)"
"File(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=<FileType.PNG: '11'>, action_manager=None)"
]
},
"execution_count": 9,
@@ -181,7 +181,7 @@
],
"source": [
"my_server_folder = my_server.file_system.create_folder(\"static\")\n",
"my_server.file_system.create_file(\"favicon.ico\", file_type=FileSystemFileType.PNG)"
"my_server.file_system.create_file(\"favicon.ico\", file_type=FileType.PNG)"
]
},
{

View File

@@ -1,242 +1,519 @@
from random import choice
from __future__ import annotations
import math
import os.path
import shutil
from pathlib import Path
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import SimComponent
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from primaite.simulator.system.core.sys_log import SysLog
_LOGGER = getLogger(__name__)
class FileSystem(SimComponent):
"""Class that contains all the simulation File System."""
def convert_size(size_bytes: int) -> str:
"""
Convert a file size from bytes to a string with a more human-readable format.
folders: Dict[str, FileSystemFolder] = {}
"""List containing all the folders in the file system."""
This function takes the size of a file in bytes and converts it to a string representation with appropriate size
units (B, KB, MB, GB, etc.).
:param size_bytes: The size of the file in bytes.
:return: The human-readable string representation of the file size.
"""
if size_bytes == 0:
return "0 B"
# Tuple of size units starting from Bytes up to Yottabytes
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
# Calculate the index (i) that will be used to select the appropriate size unit from size_name
i = int(math.floor(math.log(size_bytes, 1024)))
# Calculate the adjusted size value (s) in terms of the new size unit
p = math.pow(1024, i)
s = round(size_bytes / p, 2)
return f"{s} {size_name[i]}"
class FileSystemItemABC(SimComponent):
"""
Abstract base class for file system items used in the file system simulation.
:ivar name: The name of the FileSystemItemABC.
"""
name: str
"The name of the FileSystemItemABC."
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update({"folders": {uuid: folder.describe_state() for uuid, folder in self.folders.items()}})
state.update(
{
"name": self.name,
}
)
return state
def get_folders(self) -> Dict:
"""Returns the list of folders."""
return self.folders
@property
def size_str(self) -> str:
"""
Get the file size in a human-readable string format.
This property makes use of the :func:`convert_size` function to convert the `self.size` attribute to a string
that is easier to read and understand.
:return: The human-readable string representation of the file size.
"""
return convert_size(self.size)
class FileSystem(SimComponent):
"""Class that contains all the simulation File System."""
folders: Dict[str, Folder] = {}
"List containing all the folders in the file system."
_folders_by_name: Dict[str, Folder] = {}
sys_log: SysLog
sim_root: Path
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Ensure a default root folder
if not self.folders:
self.create_folder("root")
@property
def size(self) -> int:
"""
Calculate and return the total size of all folders in the file system.
:return: The sum of the sizes of all folders in the file system.
"""
return sum(folder.size for folder in self.folders.values())
def show(self, markdown: bool = False, full: bool = False):
"""
Prints a table of the FileSystem, displaying either just folders or full files.
:param markdown: Flag indicating if output should be in markdown format.
:param full: Flag indicating if to show full files.
"""
headers = ["Folder", "Size"]
if full:
headers[0] = "File Path"
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} File System"
for folder in self.folders.values():
if not full:
table.add_row([folder.name, folder.size_str])
else:
for file in folder.files.values():
table.add_row([file.path, file.size_str])
if full:
print(table.get_string(sortby="File Path"))
else:
print(table.get_string(sortby="Folder"))
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()}
return state
def create_folder(self, folder_name: str) -> Folder:
"""
Creates a Folder and adds it to the list of folders.
:param folder_name: The name of the folder.
"""
# check if folder with name already exists
if self.get_folder(folder_name):
raise Exception(f"Cannot create folder as it already exists: {folder_name}")
folder = Folder(name=folder_name, fs=self)
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
return folder
def delete_folder(self, folder_name: str):
"""
Deletes a folder, removes it from the folders list and removes any child folders and files.
:param folder_name: The name of the folder.
"""
if folder_name == "root":
self.sys_log.warning("Cannot delete the root folder.")
return
folder = self._folders_by_name.get(folder_name)
if folder:
for file in folder.files.values():
self.delete_file(file)
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
def create_file(
self,
file_name: str,
size: Optional[float] = None,
file_type: Optional[FileSystemFileType] = None,
folder: Optional[FileSystemFolder] = None,
folder_uuid: Optional[str] = None,
) -> FileSystemFile:
size: Optional[int] = None,
file_type: Optional[FileType] = None,
folder_name: Optional[str] = None,
real: bool = False,
) -> File:
"""
Creates a FileSystemFile and adds it to the list of files.
Creates a File and adds it to the list of files.
If no size or file_type are provided, one will be chosen randomly.
If no folder_uuid or folder is provided, a new folder will be created.
:param: file_name: The file name
:type: file_name: str
:param: size: The size the file takes on disk.
:type: size: Optional[float]
:param: file_type: The type of the file
:type: Optional[FileSystemFileType]
:param: folder: The folder to add the file to
:type: folder: Optional[FileSystemFolder]
:param: folder_uuid: The uuid of the folder to add the file to
:type: folder_uuid: Optional[str]
:param file_name: The file name.
:param size: The size the file takes on disk in bytes.
:param file_type: The type of the file.
:param folder_name: The folder to add the file to.
:param real: "Indicates whether the File is actually a real file in the Node sim fs output."
"""
file = None
folder = None
if file_type is None:
file_type = self.get_random_file_type()
# if no folder uuid provided, create a folder and add file to it
if folder_uuid is not None:
# otherwise check for existence and add file
folder = self.get_folder_by_id(folder_uuid)
if folder is not None:
if folder_name:
# check if file with name already exists
if folder.get_file_by_name(file_name):
raise Exception(f'File with name "{file_name}" already exists.')
file = FileSystemFile(name=file_name, size=size, file_type=file_type)
folder.add_file(file=file)
folder = self._folders_by_name.get(folder_name)
# If not then create it
if not folder:
folder = self.create_folder(folder_name)
else:
# check if a "root" folder exists
folder = self.get_folder_by_name("root")
if folder is None:
# create a root folder
folder = FileSystemFolder(name="root")
# Use root folder if folder_name not supplied
folder = self._folders_by_name["root"]
# add file to root folder
file = FileSystemFile(name=file_name, size=size, file_type=file_type)
folder.add_file(file)
self.folders[folder.uuid] = folder
# Create the file and add it to the folder
file = File(
name=file_name,
sim_size=size,
file_type=file_type,
folder=folder,
real=real,
sim_path=self.sim_root if real else None,
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
return file
def create_folder(
self,
folder_name: str,
) -> FileSystemFolder:
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
"""
Creates a FileSystemFolder and adds it to the list of folders.
Retrieve a file by its name from a specific folder.
:param: folder_name: The name of the folder
:type: folder_name: str
:param folder_name: The name of the folder where the file resides.
:param file_name: The name of the file to be retrieved, including its extension.
:return: An instance of File if it exists, otherwise `None`.
"""
# check if folder with name already exists
if self.get_folder_by_name(folder_name):
raise Exception(f'Folder with name "{folder_name}" already exists.')
folder = self.get_folder(folder_name)
if folder:
return folder.get_file(file_name)
self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}")
folder = FileSystemFolder(name=folder_name)
self.folders[folder.uuid] = folder
return folder
def delete_file(self, file: Optional[FileSystemFile] = None):
def delete_file(self, folder_name: str, file_name: str):
"""
Deletes a file and removes it from the files list.
Delete a file by its name from a specific folder.
:param file: The file to delete
:type file: Optional[FileSystemFile]
:param folder_name: The name of the folder containing the file.
:param file_name: The name of the file to be deleted, including its extension.
"""
# iterate through folders to delete the item with the matching uuid
for key in self.folders:
self.get_folder_by_id(key).remove_file(file)
folder = self.get_folder(folder_name)
if folder:
file = folder.get_file(file_name)
if file:
folder.remove_file(file)
self.sys_log.info(f"Deleted file /{file.path}")
def delete_folder(self, folder: FileSystemFolder):
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Deletes a folder, removes it from the folders list and removes any child folders and files.
Move a file from one folder to another.
:param folder: The folder to remove
:type folder: FileSystemFolder
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be moved.
:param dst_folder_name: The name of the destination folder.
"""
if folder is None or not isinstance(folder, FileSystemFolder):
raise Exception(f"Invalid folder: {folder}")
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
src_folder = file.folder
if self.folders.get(folder.uuid):
del self.folders[folder.uuid]
# remove file from src
src_folder.remove_file(file)
dst_folder = self.get_folder(folder_name=dst_folder_name)
if not dst_folder:
dst_folder = self.create_folder(dst_folder_name)
# add file to dst
dst_folder.add_file(file)
if file.real:
old_sim_path = file.sim_path
file.sim_path = file.folder.fs.sim_root / file.path
file.sim_path.parent.mkdir(exist_ok=True)
shutil.move(old_sim_path, file.sim_path)
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Copy a file from one folder to another.
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be copied.
:param dst_folder_name: The name of the destination folder.
"""
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
dst_folder = self.get_folder(folder_name=dst_folder_name)
if not dst_folder:
dst_folder = self.create_folder(dst_folder_name)
new_file = file.make_copy(dst_folder=dst_folder)
dst_folder.add_file(new_file)
if file.real:
new_file.sim_path.parent.mkdir(exist_ok=True)
shutil.copy2(file.sim_path, new_file.sim_path)
def get_folder(self, folder_name: str) -> Optional[Folder]:
"""
Get a folder by its name if it exists.
:param folder_name: The folder name.
:return: The matching Folder.
"""
return self._folders_by_name.get(folder_name)
def get_folder_by_id(self, folder_uuid: str) -> Optional[Folder]:
"""
Get a folder by its uuid if it exists.
:param folder_uuid: The folder uuid.
:return: The matching Folder.
"""
return self.folders.get(folder_uuid)
class Folder(FileSystemItemABC):
"""Simulation Folder."""
fs: FileSystem
"The FileSystem the Folder is in."
files: Dict[str, File] = {}
"Files stored in the folder."
_files_by_name: Dict[str, File] = {}
"Files by their name as <file name>.<file type>."
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
state["is_quarantined"] = self.is_quarantined
return state
def show(self, markdown: bool = False):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["File", "Size"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.fs.sys_log.hostname} File System Folder ({self.name})"
for file in self.files.values():
table.add_row([file.name, file.size_str])
print(table.get_string(sortby="File"))
@property
def size(self) -> int:
"""
Calculate and return the total size of all files in the folder.
:return: The total size of all files in the folder. If no files exist or all have `None`
size, returns 0.
"""
return sum(file.size for file in self.files.values() if file.size is not None)
def get_file(self, file_name: str) -> Optional[File]:
"""
Get a file by its name.
File name must be the filename and prefix, like 'memo.docx'.
:param file_name: The file name.
:return: The matching File.
"""
# TODO: Increment read count?
return self._files_by_name.get(file_name)
def get_file_by_id(self, file_uuid: str) -> File:
"""
Get a file by its uuid.
:param file_uuid: The file uuid.
:return: The matching File.
"""
return self.files.get(file_uuid)
def add_file(self, file: File):
"""
Adds a file to the folder.
:param File file: The File object to be added to the folder.
:raises Exception: If the provided `file` parameter is None or not an instance of the
`File` class.
"""
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
# check if file with id already exists in folder
if file.uuid in self.files:
_LOGGER.debug(f"File with id {file.uuid} already exists in folder")
else:
_LOGGER.debug(f"File with UUID {folder.uuid} was not found.")
# add to list
self.files[file.uuid] = file
self._files_by_name[file.name] = file
file.folder = self
def move_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder):
def remove_file(self, file: Optional[File]):
"""
Moves a file from one folder to another.
Removes a file from the folder list.
can provide
The method can take a File object or a file id.
:param: file: The file to move
:type: file: FileSystemFile
:param: src_folder: The folder where the file is located
:type: FileSystemFolder
:param: target_folder: The folder where the file should be moved to
:type: FileSystemFolder
:param file: The file to remove
"""
# check that the folders exist
if src_folder is None:
raise Exception("Source folder not provided")
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
if target_folder is None:
raise Exception("Target folder not provided")
if self.files.get(file.uuid):
self.files.pop(file.uuid)
self._files_by_name.pop(file.name)
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
if file is None:
raise Exception("File to be moved is None")
def quarantine(self):
"""Quarantines the File System Folder."""
if not self.is_quarantined:
self.is_quarantined = True
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
# check if file with name already exists
if target_folder.get_file_by_name(file.name):
raise Exception(f'Folder with name "{file.name}" already exists.')
def unquarantine(self):
"""Unquarantine of the File System Folder."""
if self.is_quarantined:
self.is_quarantined = False
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
# remove file from src
src_folder.remove_file(file)
def quarantine_status(self) -> bool:
"""Returns true if the folder is being quarantined."""
return self.is_quarantined
# add file to target
target_folder.add_file(file)
def copy_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder):
class File(FileSystemItemABC):
"""
Class representing a file in the simulation.
:ivar Folder folder: The folder in which the file resides.
:ivar FileType file_type: The type of the file.
:ivar Optional[int] sim_size: The simulated file size.
:ivar bool real: Indicates if the file is actually a real file in the Node sim fs output.
:ivar Optional[Path] sim_path: The path if the file is real.
"""
folder: Folder
"The Folder the File is in."
file_type: FileType
"The type of File."
sim_size: Optional[int] = None
"The simulated file size."
real: bool = False
"Indicates whether the File is actually a real file in the Node sim fs output."
sim_path: Optional[Path] = None
"The Path if real is True."
def __init__(self, **kwargs):
"""
Copies a file from one folder to another.
Initialise File class.
can provide
:param: file: The file to move
:type: file: FileSystemFile
:param: src_folder: The folder where the file is located
:type: FileSystemFolder
:param: target_folder: The folder where the file should be moved to
:type: FileSystemFolder
:param name: The name of the file.
:param file_type: The FileType of the file
:param size: The size of the FileSystemItemABC
"""
if src_folder is None:
raise Exception("Source folder not provided")
has_extension = "." in kwargs["name"]
if target_folder is None:
raise Exception("Target folder not provided")
# Attempt to use the file type extension to set/override the FileType
if has_extension:
extension = kwargs["name"].split(".")[-1]
kwargs["file_type"] = get_file_type_from_extension(extension)
else:
# If the file name does not have a extension, override file type to FileType.UNKNOWN
if not kwargs["file_type"]:
kwargs["file_type"] = FileType.UNKNOWN
if kwargs["file_type"] != FileType.UNKNOWN:
kwargs["name"] = f"{kwargs['name']}.{kwargs['file_type'].name.lower()}"
if file is None:
raise Exception("File to be moved is None")
# set random file size if none provided
if not kwargs.get("sim_size"):
kwargs["sim_size"] = kwargs["file_type"].default_size
super().__init__(**kwargs)
if self.real:
self.sim_path = self.folder.fs.sim_root / self.path
if not self.sim_path.exists():
self.sim_path.parent.mkdir(exist_ok=True, parents=True)
with open(self.sim_path, mode="a"):
pass
# check if file with name already exists
if target_folder.get_file_by_name(file.name):
raise Exception(f'Folder with name "{file.name}" already exists.')
# add file to target
target_folder.add_file(file)
def get_file_by_id(self, file_id: str) -> FileSystemFile:
"""Checks if the file exists in any file system folders."""
for key in self.folders:
file = self.folders[key].get_file_by_id(file_id=file_id)
if file is not None:
return file
def get_folder_by_name(self, folder_name: str) -> Optional[FileSystemFolder]:
def make_copy(self, dst_folder: Folder) -> File:
"""
Returns a the first folder with a matching name.
Create a copy of the current File object in the given destination folder.
:return: Returns the first FileSydtemFolder with a matching name
:param Folder dst_folder: The destination folder for the copied file.
:return: A new File object that is a copy of the current file.
"""
matching_folder = None
for key in self.folders:
if self.folders[key].name == folder_name:
matching_folder = self.folders[key]
break
return matching_folder
return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"}))
def get_folder_by_id(self, folder_id: str) -> FileSystemFolder:
@property
def path(self) -> str:
"""
Checks if the folder exists.
Get the path of the file in the file system.
:param: folder_id: The id of the folder to find
:type: folder_id: str
:return: The full path of the file.
"""
return self.folders[folder_id]
return f"{self.folder.name}/{self.name}"
def get_random_file_type(self) -> FileSystemFileType:
@property
def size(self) -> int:
"""
Returns a random FileSystemFileTypeEnum.
Get the size of the file in bytes.
:return: A random file type Enum
:return: The size of the file in bytes.
"""
return choice(list(FileSystemFileType))
if self.real:
return os.path.getsize(self.sim_path)
return self.sim_size
def describe_state(self) -> Dict:
"""Produce a dictionary describing the current state of this object."""
state = super().describe_state()
state["size"] = self.size
state["file_type"] = self.file_type.name
return state

View File

@@ -1,55 +0,0 @@
from random import choice
from typing import Dict
from primaite.simulator.file_system.file_system_file_type import file_type_sizes_KB, FileSystemFileType
from primaite.simulator.file_system.file_system_item_abc import FileSystemItem
class FileSystemFile(FileSystemItem):
"""Class that represents a file in the simulation."""
file_type: FileSystemFileType = None
"""The type of the FileSystemFile"""
def __init__(self, **kwargs):
"""
Initialise FileSystemFile class.
:param name: The name of the file.
:type name: str
:param file_type: The FileSystemFileType of the file
:type file_type: Optional[FileSystemFileType]
:param size: The size of the FileSystemItem
:type size: Optional[float]
"""
# set random file type if none provided
# set random file type if none provided
if kwargs.get("file_type") is None:
kwargs["file_type"] = choice(list(FileSystemFileType))
# set random file size if none provided
if kwargs.get("size") is None:
kwargs["size"] = file_type_sizes_KB[kwargs["file_type"]]
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"uuid": self.uuid,
"file_type": self.file_type.name,
}
)
return state

View File

@@ -1,132 +0,0 @@
from enum import Enum
class FileSystemFileType(str, Enum):
"""An enumeration of common file types."""
UNKNOWN = 0
"Unknown file type."
# Text formats
TXT = 1
"Plain text file."
DOC = 2
"Microsoft Word document (.doc)"
DOCX = 3
"Microsoft Word document (.docx)"
PDF = 4
"Portable Document Format."
HTML = 5
"HyperText Markup Language file."
XML = 6
"Extensible Markup Language file."
CSV = 7
"Comma-Separated Values file."
# Spreadsheet formats
XLS = 8
"Microsoft Excel file (.xls)"
XLSX = 9
"Microsoft Excel file (.xlsx)"
# Image formats
JPEG = 10
"JPEG image file."
PNG = 11
"PNG image file."
GIF = 12
"GIF image file."
BMP = 13
"Bitmap image file."
# Audio formats
MP3 = 14
"MP3 audio file."
WAV = 15
"WAV audio file."
# Video formats
MP4 = 16
"MP4 video file."
AVI = 17
"AVI video file."
MKV = 18
"MKV video file."
FLV = 19
"FLV video file."
# Presentation formats
PPT = 20
"Microsoft PowerPoint file (.ppt)"
PPTX = 21
"Microsoft PowerPoint file (.pptx)"
# Web formats
JS = 22
"JavaScript file."
CSS = 23
"Cascading Style Sheets file."
# Programming languages
PY = 24
"Python script file."
C = 25
"C source code file."
CPP = 26
"C++ source code file."
JAVA = 27
"Java source code file."
# Compressed file types
RAR = 28
"RAR archive file."
ZIP = 29
"ZIP archive file."
TAR = 30
"TAR archive file."
GZ = 31
"Gzip compressed file."
# Database file types
MDF = 32
"MS SQL Server primary database file"
NDF = 33
"MS SQL Server secondary database file"
LDF = 34
"MS SQL Server transaction log"
file_type_sizes_KB = {
FileSystemFileType.UNKNOWN: 0,
FileSystemFileType.TXT: 4,
FileSystemFileType.DOC: 50,
FileSystemFileType.DOCX: 30,
FileSystemFileType.PDF: 100,
FileSystemFileType.HTML: 15,
FileSystemFileType.XML: 10,
FileSystemFileType.CSV: 15,
FileSystemFileType.XLS: 100,
FileSystemFileType.XLSX: 25,
FileSystemFileType.JPEG: 100,
FileSystemFileType.PNG: 40,
FileSystemFileType.GIF: 30,
FileSystemFileType.BMP: 300,
FileSystemFileType.MP3: 5000,
FileSystemFileType.WAV: 25000,
FileSystemFileType.MP4: 25000,
FileSystemFileType.AVI: 50000,
FileSystemFileType.MKV: 50000,
FileSystemFileType.FLV: 15000,
FileSystemFileType.PPT: 200,
FileSystemFileType.PPTX: 100,
FileSystemFileType.JS: 10,
FileSystemFileType.CSS: 5,
FileSystemFileType.PY: 5,
FileSystemFileType.C: 5,
FileSystemFileType.CPP: 10,
FileSystemFileType.JAVA: 10,
FileSystemFileType.RAR: 1000,
FileSystemFileType.ZIP: 1000,
FileSystemFileType.TAR: 1000,
FileSystemFileType.GZ: 800,
}

View File

@@ -1,87 +0,0 @@
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_item_abc import FileSystemItem
_LOGGER = getLogger(__name__)
class FileSystemFolder(FileSystemItem):
"""Simulation FileSystemFolder."""
files: Dict[str, FileSystemFile] = {}
"""List of files stored in the folder."""
is_quarantined: bool = False
"""Flag that marks the folder as quarantined if true."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"files": {uuid: file.describe_state() for uuid, file in self.files.items()},
"is_quarantined": self.is_quarantined,
}
)
return state
def get_file_by_id(self, file_id: str) -> FileSystemFile:
"""Return a FileSystemFile with the matching id."""
return self.files.get(file_id)
def get_file_by_name(self, file_name: str) -> FileSystemFile:
"""Return a FileSystemFile with the matching id."""
return next((f for f in list(self.files) if f.name == file_name), None)
def add_file(self, file: FileSystemFile):
"""Adds a file to the folder list."""
if file is None or not isinstance(file, FileSystemFile):
raise Exception(f"Invalid file: {file}")
# check if file with id already exists in folder
if file.uuid in self.files:
_LOGGER.debug(f"File with id {file.uuid} already exists in folder")
else:
# add to list
self.files[file.uuid] = file
self.size += file.size
def remove_file(self, file: Optional[FileSystemFile]):
"""
Removes a file from the folder list.
The method can take a FileSystemFile object or a file id.
:param: file: The file to remove
:type: Optional[FileSystemFile]
"""
if file is None or not isinstance(file, FileSystemFile):
raise Exception(f"Invalid file: {file}")
if self.files.get(file.uuid):
del self.files[file.uuid]
self.size -= file.size
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
def quarantine(self):
"""Quarantines the File System Folder."""
self.is_quarantined = True
def end_quarantine(self):
"""Ends the quarantine of the File System Folder."""
self.is_quarantined = False
def quarantine_status(self) -> bool:
"""Returns true if the folder is being quarantined."""
return self.is_quarantined

View File

@@ -1,31 +0,0 @@
from typing import Dict
from primaite.simulator.core import SimComponent
class FileSystemItem(SimComponent):
"""Abstract base class for FileSystemItems used in the file system simulation."""
name: str
"""The name of the FileSystemItem."""
size: float = 0
"""The size the item takes up on disk."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"name": self.name,
"size": self.size,
}
)
return state

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
from enum import Enum
from random import choice
from typing import Any
class FileType(Enum):
"""An enumeration of common file types."""
UNKNOWN = 0
"Unknown file type."
# Text formats
TXT = 1
"Plain text file."
DOC = 2
"Microsoft Word document (.doc)"
DOCX = 3
"Microsoft Word document (.docx)"
PDF = 4
"Portable Document Format."
HTML = 5
"HyperText Markup Language file."
XML = 6
"Extensible Markup Language file."
CSV = 7
"Comma-Separated Values file."
# Spreadsheet formats
XLS = 8
"Microsoft Excel file (.xls)"
XLSX = 9
"Microsoft Excel file (.xlsx)"
# Image formats
JPEG = 10
"JPEG image file."
PNG = 11
"PNG image file."
GIF = 12
"GIF image file."
BMP = 13
"Bitmap image file."
# Audio formats
MP3 = 14
"MP3 audio file."
WAV = 15
"WAV audio file."
# Video formats
MP4 = 16
"MP4 video file."
AVI = 17
"AVI video file."
MKV = 18
"MKV video file."
FLV = 19
"FLV video file."
# Presentation formats
PPT = 20
"Microsoft PowerPoint file (.ppt)"
PPTX = 21
"Microsoft PowerPoint file (.pptx)"
# Web formats
JS = 22
"JavaScript file."
CSS = 23
"Cascading Style Sheets file."
# Programming languages
PY = 24
"Python script file."
C = 25
"C source code file."
CPP = 26
"C++ source code file."
JAVA = 27
"Java source code file."
# Compressed file types
RAR = 28
"RAR archive file."
ZIP = 29
"ZIP archive file."
TAR = 30
"TAR archive file."
GZ = 31
"Gzip compressed file."
# Database file types
DB = 32
"Generic DB file. Used by sqlite3."
@classmethod
def _missing_(cls, value: Any) -> FileType:
return cls.UNKNOWN
@classmethod
def random(cls) -> FileType:
"""
Returns a random FileType.
:return: A random FileType.
"""
return choice(list(FileType))
@property
def default_size(self) -> int:
"""
Get the default size of the FileType in bytes.
Returns 0 if a default size does not exist.
"""
size = file_type_sizes_bytes[self]
return size if size else 0
def get_file_type_from_extension(file_type_extension: str) -> FileType:
"""
Get a FileType from a file type extension.
If a matching extension does not exist, FileType.UNKNOWN is returned.
:param file_type_extension: A file type extension.
:return: A file type extension.
"""
try:
return FileType[file_type_extension.upper()]
except KeyError:
return FileType.UNKNOWN
file_type_sizes_bytes = {
FileType.UNKNOWN: 0,
FileType.TXT: 4096,
FileType.DOC: 51200,
FileType.DOCX: 30720,
FileType.PDF: 102400,
FileType.HTML: 15360,
FileType.XML: 10240,
FileType.CSV: 15360,
FileType.XLS: 102400,
FileType.XLSX: 25600,
FileType.JPEG: 102400,
FileType.PNG: 40960,
FileType.GIF: 30720,
FileType.BMP: 307200,
FileType.MP3: 5120000,
FileType.WAV: 25600000,
FileType.MP4: 25600000,
FileType.AVI: 51200000,
FileType.MKV: 51200000,
FileType.FLV: 15360000,
FileType.PPT: 204800,
FileType.PPTX: 102400,
FileType.JS: 10240,
FileType.CSS: 5120,
FileType.PY: 5120,
FileType.C: 5120,
FileType.CPP: 10240,
FileType.JAVA: 10240,
FileType.RAR: 1024000,
FileType.ZIP: 1024000,
FileType.TAR: 1024000,
FileType.GZ: 819200,
FileType.DB: 15360000,
}

View File

@@ -29,6 +29,8 @@ class Network(SimComponent):
nodes: Dict[str, Node] = {}
links: Dict[str, Link] = {}
_node_id_map: Dict[int, Node] = {}
_link_id_map: Dict[int, Node] = {}
def __init__(self, **kwargs):
"""
@@ -161,8 +163,8 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
"nodes": {i for i, node in self._node_id_map.items()},
"links": {i: link.describe_state() for i, link in self._link_id_map.items()},
}
)
return state
@@ -179,9 +181,10 @@ class Network(SimComponent):
_LOGGER.warning(f"Can't add node {node.uuid}. It is already in the network.")
return
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -209,6 +212,10 @@ class Network(SimComponent):
_LOGGER.warning(f"Can't remove node {node.uuid}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
if node == _node:
self._node_id_map.pop(i)
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
@@ -235,9 +242,10 @@ class Network(SimComponent):
return
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
link.parent = self
_LOGGER.info(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
def remove_link(self, link: Link) -> None:
"""Disconnect a link from the network.
@@ -248,6 +256,10 @@ class Network(SimComponent):
link.endpoint_a.disconnect_link()
link.endpoint_b.disconnect_link()
self.links.pop(link.uuid)
for i, _link in self._link_id_map.items():
if link == _link:
self._link_id_map.pop(i)
break
link.parent = None
_LOGGER.info(f"Removed link {link.uuid} from network {self.uuid}.")

View File

@@ -4,12 +4,14 @@ import re
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
@@ -184,7 +186,7 @@ class NIC(SimComponent):
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.info(f"NIC {self} disabled")
_LOGGER.debug(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
@@ -206,7 +208,7 @@ class NIC(SimComponent):
# TODO: Inform the Node that a link has been connected
self.connected_link = link
self.enable()
_LOGGER.info(f"NIC {self} connected to Link {link}")
_LOGGER.debug(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
@@ -349,7 +351,7 @@ class SwitchPort(SimComponent):
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.info(f"SwitchPort {self} disabled")
_LOGGER.debug(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
@@ -369,7 +371,7 @@ class SwitchPort(SimComponent):
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
_LOGGER.info(f"SwitchPort {self} connected to Link {link}")
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
@@ -475,13 +477,13 @@ class Link(SimComponent):
def endpoint_up(self):
"""Let the Link know and endpoint has been brought up."""
if self.is_up:
_LOGGER.info(f"Link {self} up")
_LOGGER.debug(f"Link {self} up")
def endpoint_down(self):
"""Let the Link know and endpoint has been brought down."""
if not self.is_up:
self.current_load = 0.0
_LOGGER.info(f"Link {self} down")
_LOGGER.debug(f"Link {self} down")
@property
def is_up(self) -> bool:
@@ -508,7 +510,7 @@ class Link(SimComponent):
"""
can_transmit = self._can_transmit(frame)
if not can_transmit:
_LOGGER.info(f"Cannot transmit frame as {self} is at capacity")
_LOGGER.debug(f"Cannot transmit frame as {self} is at capacity")
return False
receiver = self.endpoint_a
@@ -520,7 +522,7 @@ class Link(SimComponent):
# Frame transmitted successfully
# Load the frame size on the link
self.current_load += frame_size
_LOGGER.info(
_LOGGER.debug(
f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits "
f"({self.current_load_percent})"
)
@@ -890,6 +892,8 @@ class Node(SimComponent):
"All processes on the node."
file_system: FileSystem
"The nodes file system."
root: Path
"Root directory for simulation output."
sys_log: SysLog
arp: ARPCache
icmp: ICMP
@@ -917,14 +921,19 @@ class Node(SimComponent):
kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
kwargs["software_manager"] = SoftwareManager(
sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager")
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
)
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem()
super().__init__(**kwargs)
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
def describe_state(self) -> Dict:
"""
@@ -950,7 +959,25 @@ class Node(SimComponent):
)
return state
def show(self, markdown: bool = False):
def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"):
"""A multi-use .show function that accepts either NIC or OPEN_PORTS."""
if component == "NIC":
self._show_nic(markdown)
elif component == "OPEN_PORTS":
self._show_open_ports(markdown)
def _show_open_ports(self, markdown: bool = False):
"""Prints a table of the open ports on the Node."""
table = PrettyTable(["Port", "Name"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
table.add_row([port.value, port.name])
print(table)
def _show_nic(self, markdown: bool = False):
"""Prints a table of the NICs on the Node."""
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
if markdown:
@@ -1039,29 +1066,30 @@ class Node(SimComponent):
:param pings: The number of pings to attempt, default is 4.
:return: True if the ping is successful, otherwise False.
"""
if not isinstance(target_ip_address, IPv4Address):
target_ip_address = IPv4Address(target_ip_address)
if target_ip_address.is_loopback:
self.sys_log.info("Pinging loopback address")
return any(nic.enabled for nic in self.nics.values())
if self.operating_state == NodeOperatingState.ON:
self.sys_log.info(f"Pinging {target_ip_address}:")
sequence, identifier = 0, None
while sequence < pings:
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
request_replies = self.icmp.request_replies.get(identifier)
passed = request_replies == pings
if request_replies:
self.icmp.request_replies.pop(identifier)
else:
request_replies = 0
self.sys_log.info(
f"Ping statistics for {target_ip_address}: "
f"Packets: Sent = {pings}, "
f"Received = {request_replies}, "
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
)
return passed
if not isinstance(target_ip_address, IPv4Address):
target_ip_address = IPv4Address(target_ip_address)
if target_ip_address.is_loopback:
self.sys_log.info("Pinging loopback address")
return any(nic.enabled for nic in self.nics.values())
if self.operating_state == NodeOperatingState.ON:
self.sys_log.info(f"Pinging {target_ip_address}:")
sequence, identifier = 0, None
while sequence < pings:
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
request_replies = self.icmp.request_replies.get(identifier)
passed = request_replies == pings
if request_replies:
self.icmp.request_replies.pop(identifier)
else:
request_replies = 0
self.sys_log.info(
f"Ping statistics for {target_ip_address}: "
f"Packets: Sent = {pings}, "
f"Received = {request_replies}, "
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
)
return passed
return False
def send_frame(self, frame: Frame):
@@ -1070,7 +1098,8 @@ class Node(SimComponent):
:param frame: The Frame to be sent.
"""
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
if self.operating_state == NodeOperatingState.ON:
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
nic.send_frame(frame)
def receive_frame(self, frame: Frame, from_nic: NIC):
@@ -1083,18 +1112,27 @@ class Node(SimComponent):
:param frame: The Frame being received.
:param from_nic: The NIC that received the frame.
"""
if frame.ip:
if frame.ip.src_ip_address in self.arp:
self.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
)
if frame.ip.protocol == IPProtocol.TCP:
if frame.tcp.src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
elif frame.ip.protocol == IPProtocol.UDP:
pass
elif frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
if self.operating_state == NodeOperatingState.ON:
if frame.ip:
if frame.ip.src_ip_address in self.arp:
self.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
)
if frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
return
# Check if the destination port is open on the Node
if frame.tcp.dst_port in self.software_manager.get_open_ports():
# accept thr frame as the port is open
if frame.tcp.src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
else:
self.session_manager.receive_frame(frame)
else:
# denied as port closed
self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}")
# TODO: do we need to do anything more here?
pass
def install_service(self, service: Service) -> None:
"""
@@ -1110,7 +1148,7 @@ class Node(SimComponent):
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.info(f"Added service {service.uuid} to node {self.uuid}")
_LOGGER.debug(f"Added service {service.uuid} to node {self.uuid}")
def uninstall_service(self, service: Service) -> None:
"""Uninstall and completely remove service from this node.
@@ -1125,7 +1163,7 @@ class Node(SimComponent):
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
_LOGGER.info(f"Removed service {service.uuid} from node {self.uuid}")
_LOGGER.debug(f"Removed service {service.uuid} from node {self.uuid}")
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):

View File

@@ -1,3 +1,5 @@
from ipaddress import IPv4Address
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -6,6 +8,9 @@ from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def client_server_routed() -> Network:
@@ -125,6 +130,9 @@ def arcd_uc2_network() -> Network:
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
# Client 2
client_2 = Computer(
@@ -143,13 +151,6 @@ def arcd_uc2_network() -> Network:
domain_controller.power_on()
network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
# Web Server
web_server = Server(
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
web_server.power_on()
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
# Database Server
database_server = Server(
hostname="database_server",
@@ -160,6 +161,51 @@ def arcd_uc2_network() -> Network:
database_server.power_on()
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
ddl = """
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(50) NOT NULL,
email VARCHAR(50) NOT NULL,
age INT,
city VARCHAR(50),
occupation VARCHAR(50)
);"""
user_insert_statements = [
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
]
database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
database_service.start()
database_service._process_sql(ddl, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None) # noqa
# Web Server
web_server = Server(
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
database_client.run()
database_client.connect()
# Backup Server
backup_server = Server(
hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
@@ -183,4 +229,6 @@ def arcd_uc2_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER)
return network

View File

@@ -23,9 +23,9 @@ class Application(IOSoftware):
Applications are user-facing programs that may perform input/output operations.
"""
operating_state: ApplicationOperatingState
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
"The current operating state of the Application."
execution_control_status: str
execution_control_status: str = "manual"
"Control status of the application's execution. It could be 'manual' or 'automatic'."
num_executions: int = 0
"The number of times the application has been executed. Default is 0."
@@ -53,6 +53,25 @@ class Application(IOSoftware):
)
return state
def run(self) -> None:
"""Open the Application."""
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
def close(self) -> None:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:
self.sys_log.info(f"Closed Application{self.name}")
self.operating_state = ApplicationOperatingState.CLOSED
def install(self) -> None:
"""Install Application."""
super().install()
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Installing Application {self.name}")
self.operating_state = ApplicationOperatingState.INSTALLING
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.

View File

@@ -0,0 +1,157 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from prettytable import PrettyTable
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.software_manager import SoftwareManager
class DatabaseClient(Application):
"""
A DatabaseClient application.
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
Database Service. It mainly operates over TCP protocol.
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
"""
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseClient"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
:return: A dictionary representing the current state.
"""
pass
return super().describe_state()
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
"""
Configure the DatabaseClient to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
"""
self.server_ip_address = server_ip_address
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self) -> bool:
"""Connect to a Database Service."""
if not self.connected and self.operating_state.RUNNING:
return self._connect(self.server_ip_address, self.server_password)
return False
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
) -> bool:
if is_reattempt:
if self.connected:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised")
self.server_ip_address = server_ip_address
return self.connected
else:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined")
return False
payload = {"type": "connect_request", "password": password}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
)
return self._connect(server_ip_address, password, True)
def disconnect(self):
"""Disconnect from the Database Service."""
if self.connected and self.operating_state.RUNNING:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
)
self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}")
self.server_ip_address = None
self.connected = False
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
if is_reattempt:
success = self._query_success_tracker.get(query_id)
if success:
return True
return False
else:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "sql", "sql": sql, "uuid": query_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
self.operating_state = ApplicationOperatingState.RUNNING
self.connect()
def query(self, sql: str) -> bool:
"""
Send a query to the Database Service.
:param sql: The SQL query.
:return: True if the query was successful, otherwise False.
"""
if self.connected and self.operating_state.RUNNING:
query_id = str(uuid4())
# Initialise the tracker of this ID to False
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
def _print_data(self, data: Dict):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
if data:
table = PrettyTable(list(data.values())[0])
table.align = "l"
table.title = f"{self.sys_log.hostname} Database Client"
for row in data.values():
table.add_row(row.values())
print(table)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response":
self.connected = payload["response"] == True
elif payload["type"] == "sql":
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
self._print_data(payload["data"])
return True

View File

@@ -1,8 +1,9 @@
import json
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Dict, List, Optional
from primaite.simulator import TEMP_SIM_OUTPUT
from primaite.simulator import SIM_OUTPUT
class _JSONFilter(logging.Filter):
@@ -51,6 +52,18 @@ class PacketCapture:
self.logger.addFilter(_JSONFilter())
def read(self) -> List[Dict[str, Any]]:
"""
Read packet capture logs and return them as a list of dictionaries.
:return: List of frames captured, represented as dictionaries.
"""
frames = []
with open(self._get_log_path(), "r") as file:
while line := file.readline():
frames.append(json.loads(line.rstrip()))
return frames
@property
def _logger_name(self) -> str:
"""Get PCAP the logger name."""
@@ -62,7 +75,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = TEMP_SIM_OUTPUT / self.hostname
root = SIM_OUTPUT / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -32,27 +32,23 @@ class Session(SimComponent):
"""
protocol: IPProtocol
src_ip_address: IPv4Address
dst_ip_address: IPv4Address
with_ip_address: IPv4Address
src_port: Optional[Port]
dst_port: Optional[Port]
connected: bool = False
@classmethod
def from_session_key(
cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]
) -> Session:
def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session:
"""
Create a Session instance from a session key tuple.
:param session_key: Tuple containing the session details.
:return: A Session instance.
"""
protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key
protocol, with_ip_address, src_port, dst_port = session_key
return Session(
protocol=protocol,
src_ip_address=src_ip_address,
dst_ip_address=dst_ip_address,
with_ip_address=with_ip_address,
src_port=src_port,
dst_port=dst_port,
)
@@ -99,8 +95,8 @@ class SessionManager:
@staticmethod
def _get_session_key(
frame: Frame, from_source: bool = True
) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]:
frame: Frame, inbound_frame: bool = True
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
"""
Extracts the session key from the given frame.
@@ -112,36 +108,36 @@ class SessionManager:
- Optional[Port]: The destination port number (if applicable).
:param frame: The network frame from which to extract the session key.
:param from_source: A flag to indicate if the key should be extracted from the source or destination.
:return: A tuple containing the session key.
"""
protocol = frame.ip.protocol
src_ip_address = frame.ip.src_ip_address
dst_ip_address = frame.ip.dst_ip_address
with_ip_address = frame.ip.src_ip_address
if protocol == IPProtocol.TCP:
if from_source:
if inbound_frame:
src_port = frame.tcp.src_port
dst_port = frame.tcp.dst_port
else:
dst_port = frame.tcp.src_port
src_port = frame.tcp.dst_port
with_ip_address = frame.ip.dst_ip_address
elif protocol == IPProtocol.UDP:
if from_source:
if inbound_frame:
src_port = frame.udp.src_port
dst_port = frame.udp.dst_port
else:
dst_port = frame.udp.src_port
src_port = frame.udp.dst_port
with_ip_address = frame.ip.dst_ip_address
else:
src_port = None
dst_port = None
return protocol, src_ip_address, dst_ip_address, src_port, dst_port
return protocol, with_ip_address, src_port, dst_port
def receive_payload_from_software_manager(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
dst_ip_address: Optional[IPv4Address] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Union[Any, None]:
@@ -154,20 +150,21 @@ class SessionManager:
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
"""
if session_id:
dest_ip_address = self.sessions_by_uuid[session_id].dst_ip_address
dest_port = self.sessions_by_uuid[session_id].dst_port
session = self.sessions_by_uuid[session_id]
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
dst_port = self.sessions_by_uuid[session_id].dst_port
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dest_ip_address)
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
if dst_mac_address:
outbound_nic = self.arp_cache.get_arp_cache_nic(dest_ip_address)
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
else:
if not is_reattempt:
self.arp_cache.send_arp_request(dest_ip_address)
self.arp_cache.send_arp_request(dst_ip_address)
return self.receive_payload_from_software_manager(
payload=payload,
dest_ip_address=dest_ip_address,
dest_port=dest_port,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
@@ -178,17 +175,17 @@ class SessionManager:
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
ip=IPPacket(
src_ip_address=outbound_nic.ip_address,
dst_ip_address=dest_ip_address,
dst_ip_address=dst_ip_address,
),
tcp=TCPHeader(
src_port=dest_port,
dst_port=dest_port,
src_port=dst_port,
dst_port=dst_port,
),
payload=payload,
)
if not session_id:
session_key = self._get_session_key(frame, from_source=True)
session_key = self._get_session_key(frame, inbound_frame=False)
session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
@@ -198,33 +195,25 @@ class SessionManager:
outbound_nic.send_frame(frame)
def send_payload_to_software_manager(self, payload: Any, session_id: int):
def receive_frame(self, frame: Frame):
"""
Send a payload to the software manager.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload originates from.
"""
self.software_manager.receive_payload_from_session_manger()
def receive_payload_from_nic(self, frame: Frame):
"""
Receive a Frame from the NIC.
Receive a Frame.
Extract the session key using the _get_session_key method, and forward the payload to the appropriate
session. If the session does not exist, a new one is created.
:param frame: The frame being received.
"""
session_key = self._get_session_key(frame)
session = self.sessions_by_key.get(session_key)
session_key = self._get_session_key(frame, inbound_frame=True)
session: Session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
self.software_manager.receive_payload_from_session_manger(payload=frame, session=session)
# TODO: Implement the frame deconstruction and send to SoftwareManager.
self.software_manager.receive_payload_from_session_manger(
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
)
def show(self, markdown: bool = False):
"""

View File

@@ -1,15 +1,15 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareType
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import IOSoftware
if TYPE_CHECKING:
from primaite.simulator.system.core.session_manager import SessionManager
@@ -17,76 +17,90 @@ if TYPE_CHECKING:
from typing import Type, TypeVar
ServiceClass = TypeVar("ServiceClass", bound=Service)
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
class SoftwareManager:
"""A class that manages all running Services and Applications on a Node and facilitates their communication."""
def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"):
def __init__(self, session_manager: "SessionManager", sys_log: SysLog, file_system: FileSystem):
"""
Initialize a new instance of SoftwareManager.
:param session_manager: The session manager handling network communications.
"""
self.session_manager = session_manager
self.services: Dict[str, Service] = {}
self.applications: Dict[str, Application] = {}
self.software: Dict[str, Union[Service, Application]] = {}
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
def add_service(self, service_class: Type[ServiceClass]):
def get_open_ports(self) -> List[Port]:
"""
Add a Service to the manager.
Get a list of open ports.
:param: service_class: The class of the service to add
:return: A list of all open ports on the Node.
"""
service = service_class(software_manager=self, sys_log=self.sys_log)
open_ports = [Port.ARP]
for software in self.port_protocol_mapping.values():
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
open_ports.append(software.port)
open_ports.sort(key=lambda port: port.value)
return open_ports
service.software_manager = self
self.services[service.name] = service
self.port_protocol_mapping[(service.port, service.protocol)] = service
def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol):
def install(self, software_class: Type[IOSoftwareClass]):
"""
Add an Application to the manager.
Install an Application or Service.
:param name: The name of the application.
:param application: The application instance.
:param port: The port used by the application.
:param protocol: The network protocol used by the application.
:param software_class: The software class.
"""
application.software_manager = self
self.applications[name] = application
self.port_protocol_mapping[(port, protocol)] = application
if software_class in self._software_class_to_name_map:
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
return
software = software_class(software_manager=self, sys_log=self.sys_log, file_system=self.file_system)
if isinstance(software, Application):
software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
self.sys_log.info(f"Installed {software.name}")
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any):
def uninstall(self, software_name: str):
"""
Uninstall an Application or Service.
:param software_name: The software name.
"""
if software_name in self.software:
software = self.software.pop(software_name) # noqa
del software
self.sys_log.info(f"Deleted {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
def send_internal_payload(self, target_software: str, payload: Any):
"""
Send a payload to a specific service or application.
:param target_software: The name of the target service or application.
:param target_software_type: The type of software (Service, Application, Process).
:param payload: The data to be sent.
:param receiver_type: The type of the target, either 'service' or 'application'.
"""
if target_software_type is SoftwareType.SERVICE:
receiver = self.services.get(target_software)
elif target_software_type is SoftwareType.APPLICATION:
receiver = self.applications.get(target_software)
else:
raise ValueError(f"Invalid receiver type {target_software_type}")
receiver = self.software.get(target_software)
if receiver:
receiver.receive_payload(payload)
else:
raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}")
self.sys_log.error(f"No Service of Application found with the name {target_software}")
def send_payload_to_session_manager(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[int] = None,
session_id: Optional[str] = None,
):
"""
Send a payload to the SessionManager.
@@ -97,21 +111,21 @@ class SoftwareManager:
:param session_id: The Session ID the payload is to originate from. Optional.
"""
self.session_manager.receive_payload_from_software_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)
def receive_payload_from_session_manger(self, payload: Any, session: Session):
def receive_payload_from_session_manger(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.
:param payload: The payload being received.
:param session: The transport session the payload originates from.
"""
# receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
# if receiver:
# receiver.receive_payload(None, payload)
# else:
# raise ValueError(f"No service or application found for port {port} and protocol {protocol}")
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive(payload=payload, session_id=session_id)
else:
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
pass
def show(self, markdown: bool = False):
@@ -120,13 +134,20 @@ class SoftwareManager:
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(["Name", "Operating State", "Health State", "Port"])
table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} Software Manager"
for service in self.services.values():
for software in self.port_protocol_mapping.values():
software_type = "Service" if isinstance(software, Service) else "Application"
table.add_row(
[service.name, service.operating_state.name, service.health_state_actual.name, service.port.value]
[
software.name,
software_type,
software.operating_state.name,
software.health_state_actual.name,
software.port.value,
]
)
print(table)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator import TEMP_SIM_OUTPUT
from primaite.simulator import SIM_OUTPUT
class _NotJSONFilter(logging.Filter):
@@ -81,7 +81,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = TEMP_SIM_OUTPUT / self.hostname
root = SIM_OUTPUT / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -1,76 +0,0 @@
from typing import Dict
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.system.services.service import Service
class DatabaseService(Service):
"""Service loosely modelled on Microsoft SQL Server."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def uninstall(self) -> None:
"""
Undo installation procedure.
This method deletes files created when installing the database, and the database folder if it is empty.
"""
super().uninstall()
node: Node = self.parent
node.file_system.delete_file(self.primary_store)
node.file_system.delete_file(self.transaction_log)
if self.secondary_store:
node.file_system.delete_file(self.secondary_store)
if len(self.folder.files) == 0:
node.file_system.delete_folder(self.folder)
def install(self) -> None:
"""Perform first time install on a node, creating necessary files."""
super().install()
assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node"
self._setup_files()
def _setup_files(
self,
db_size: int = 1000,
use_secondary_db_file: bool = False,
secondary_db_size: int = 300,
folder_name: str = "database",
):
"""Set up files that are required by the database on the parent host.
:param db_size: Initial file size of the main database file, defaults to 1000
:type db_size: int, optional
:param use_secondary_db_file: Whether to use a secondary database file, defaults to False
:type use_secondary_db_file: bool, optional
:param secondary_db_size: Size of the secondary db file, defaults to None
:type secondary_db_size: int, optional
:param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database"
:type folder_name: str, optional
"""
# note that this parent.file_system.create_folder call in the future will be authenticated by using permissions
# handler. This permission will be granted based on service account given to the database service.
self.parent: Node
self.folder = self.parent.file_system.create_folder(folder_name)
self.primary_store = self.parent.file_system.create_file(
"db_primary_store", db_size, FileSystemFileType.MDF, folder=self.folder
)
self.transaction_log = self.parent.file_system.create_file(
"db_transaction_log", "1", FileSystemFileType.LDF, folder=self.folder
)
if use_secondary_db_file:
self.secondary_store = self.parent.file_system.create_file(
"db_secondary_store", secondary_db_size, FileSystemFileType.NDF, folder=self.folder
)
else:
self.secondary_store = None

View File

@@ -0,0 +1,155 @@
import sqlite3
from datetime import datetime
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
class DatabaseService(Service):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
"""
password: Optional[str] = None
connections: Dict[str, datetime] = {}
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseService"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._conn = sqlite3.connect(self._db_file.sim_path)
self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
"""
Get a list of table names present in the database.
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql)
return [row[0] for row in results["data"]]
def show(self, markdown: bool = False):
"""
Prints a list of table names in the database using PrettyTable.
:param markdown: Whether to output the table in Markdown format.
"""
table = PrettyTable(["Table"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.file_system.sys_log.hostname} Database"
for row in self.tables():
table.add_row([row])
print(table)
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
def _process_connect(
self, session_id: str, password: Optional[str] = None
) -> Dict[str, Union[int, Dict[str, bool]]]:
status_code = 500 # Default internal server error
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.password == password:
status_code = 200 # ok
self.connections[session_id] = datetime.now()
self.sys_log.info(f"Connect request for {session_id=} authorised")
else:
status_code = 401 # Unauthorised
self.sys_log.info(f"Connect request for {session_id=} declined")
else:
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
try:
self._cursor.execute(query)
self._conn.commit()
except OperationalError:
# Handle the case where the table does not exist.
self.sys_log.error(f"{self.name}: Error, query failed")
return {"status_code": 404, "data": {}}
data = []
description = self._cursor.description
if description:
headers = []
for header in description:
headers.append(header[0])
data = self._cursor.fetchall()
if data and headers:
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Processes the incoming SQL payload and sends the result back.
:param payload: The SQL query to be executed.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
result = {"status_code": 500, "data": []}
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
result = self._process_connect(session_id=session_id, password=payload.get("password"))
elif payload["type"] == "disconnect":
if session_id in self.connections:
self.connections.pop(session_id)
elif payload["type"] == "sql":
if session_id in self.connections:
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
else:
result = {"status_code": 401, "type": "sql"}
self.send(payload=result, session_id=session_id)
return True
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Send a SQL response back down to the SessionManager.
:param payload: The SQL query results.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return payload["status_code"] == 200

View File

@@ -0,0 +1,49 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite.simulator.system.applications.database_client import DatabaseClient
class DataManipulationBot(DatabaseClient):
"""
Red Agent Data Integration Service.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
"""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The data manipulation query payload.
"""
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}.")
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"Attempting to start the {self.name}")
super().run()
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_io_address and payload.")

View File

@@ -1,34 +0,0 @@
from ipaddress import IPv4Address
from typing import Any, Optional
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
class DataManipulatorService(Service):
"""
Red Agent Data Integration Service.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
"""
def __init__(self, **kwargs):
kwargs["name"] = "DataManipulatorBot"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def start(self, target_ip_address: IPv4Address, payload: Optional[Any] = "DELETE TABLE users", **kwargs):
"""
Run the DataManipulatorService actions.
:param: target_ip_address: The IP address of the target machine to attack
:param: payload: The payload to send to the target machine
"""
super().start()
self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=target_ip_address, dest_port=self.port
)

View File

@@ -98,35 +98,30 @@ class Service(IOSoftware):
def stop(self) -> None:
"""Stop the service."""
_LOGGER.debug(f"Stopping service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
def start(self, **kwargs) -> None:
"""Start the service."""
_LOGGER.debug(f"Starting service {self.name}")
if self.operating_state == ServiceOperatingState.STOPPED:
self.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def pause(self) -> None:
"""Pause the service."""
_LOGGER.debug(f"Pausing service {self.name}")
if self.operating_state == ServiceOperatingState.RUNNING:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
def resume(self) -> None:
"""Resume paused service."""
_LOGGER.debug(f"Resuming service {self.name}")
if self.operating_state == ServiceOperatingState.PAUSED:
self.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def restart(self) -> None:
"""Restart running service."""
_LOGGER.debug(f"Restarting service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
@@ -134,13 +129,11 @@ class Service(IOSoftware):
def disable(self) -> None:
"""Disable the service."""
_LOGGER.debug(f"Disabling service {self.name}")
self.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
def enable(self) -> None:
"""Enable the disabled service."""
_LOGGER.debug(f"Enabling service {self.name}")
if self.operating_state == ServiceOperatingState.DISABLED:
self.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED

View File

@@ -1,8 +1,9 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict
from typing import Any, Dict, Optional
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.sys_log import SysLog
@@ -79,6 +80,10 @@ class Software(SimComponent):
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
file_system: FileSystem
"The FileSystem of the Node the Software is installed on."
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
@@ -216,7 +221,6 @@ class IOSoftware(Software):
:param kwargs: Additional keyword arguments specific to the implementation.
:return: True if the payload was successfully sent, False otherwise.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""

View File

@@ -12,6 +12,8 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.simulator.network.container import Network
from primaite.simulator.network.networks import arcd_uc2_network
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -19,7 +21,22 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
@pytest.fixture(scope="function")
def uc2_network() -> Network:
return arcd_uc2_network()
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
return Node(hostname="fs_node").file_system
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.

View File

View File

@@ -0,0 +1,25 @@
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_data_manipulation(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
database_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = database_server.software_manager.software["DatabaseService"]
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
# First check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT * FROM user;")

View File

@@ -1,52 +1,59 @@
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
def test_installing_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
port=Port.SQL_SERVER,
operating_state=ServiceOperatingState.RUNNING,
)
def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
node = Node(hostname="db-server")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
node.install_service(db)
assert len(db_service.connections) == 1
assert db in node
file_exists = False
for folder in node.file_system.folders.values():
for file in folder.files.values():
if file.name == "db_primary_store":
file_exists = True
break
if file_exists:
break
assert file_exists
db_client.disconnect()
assert len(db_service.connections) == 0
def test_uninstalling_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
port=Port.SQL_SERVER,
operating_state=ServiceOperatingState.RUNNING,
)
def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
node = Node(hostname="db-server")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
node.install_service(db)
db_client.disconnect()
node.uninstall_service(db)
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
db_service.password = "12345"
assert db not in node
assert node.file_system.get_folder_by_name("database") is None
assert db_client.connect()
assert len(db_service.connections) == 1
def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
db_service.password = "12345"
assert not db_client.connect()
assert len(db_service.connections) == 0
def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT * FROM user;")

View File

@@ -1,132 +1,144 @@
import pytest
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
from primaite.simulator.file_system.file_type import FileType
def test_create_folder_and_file():
def test_create_folder_and_file(file_system):
"""Test creating a folder and a file."""
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
assert len(file_system.folders) == 1
file_system.create_folder(folder_name="test_folder")
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid)
assert len(file_system.get_folder_by_id(folder.uuid).files) is 1
assert len(file_system.folders) is 2
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
assert file_system.get_file_by_id(file.uuid).name is "test_file"
assert file_system.get_file_by_id(file.uuid).size == 10
assert len(file_system.get_folder("test_folder").files) == 1
assert file_system.get_folder("test_folder").get_file("test_file.txt")
def test_create_file():
def test_create_file_no_folder(file_system):
"""Tests that creating a file without a folder creates a folder and sets that as the file's parent."""
file_system = FileSystem()
file = file_system.create_file(file_name="test_file", size=10)
file = file_system.create_file(file_name="test_file.txt", size=10)
assert len(file_system.folders) is 1
assert file_system.get_folder_by_name("root").get_file_by_id(file.uuid) is file
assert file_system.get_folder("root").get_file("test_file.txt") == file
assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT
assert file_system.get_folder("root").get_file("test_file.txt").size == 10
def test_delete_file():
def test_create_file_no_extension(file_system):
"""Tests that creating a file without an extension sets the file type to FileType.UNKNOWN."""
file = file_system.create_file(file_name="test_file")
assert len(file_system.folders) is 1
assert file_system.get_folder("root").get_file("test_file") == file
assert file_system.get_folder("root").get_file("test_file").file_type == FileType.UNKNOWN
assert file_system.get_folder("root").get_file("test_file").size == 0
def test_delete_file(file_system):
"""Tests that a file can be deleted."""
file_system = FileSystem()
file_system.create_file(file_name="test_file.txt")
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 1
file = file_system.create_file(file_name="test_file", size=10)
assert len(file_system.folders) is 1
folder_id = list(file_system.folders.keys())[0]
folder = file_system.get_folder_by_id(folder_id)
assert folder.get_file_by_id(file.uuid) is file
file_system.delete_file(file=file)
assert len(file_system.folders) is 1
assert len(folder.files) is 0
file_system.delete_file(folder_name="root", file_name="test_file.txt")
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 0
def test_delete_non_existent_file():
def test_delete_non_existent_file(file_system):
"""Tests deleting a non existent file."""
file_system = FileSystem()
file = file_system.create_file(file_name="test_file", size=10)
not_added_file = FileSystemFile(name="not_added")
file_system.create_file(file_name="test_file.txt")
# folder should be created
assert len(file_system.folders) is 1
assert len(file_system.folders) == 1
# should only have 1 file in the file system
folder_id = list(file_system.folders.keys())[0]
folder = file_system.get_folder_by_id(folder_id)
assert len(list(folder.files)) is 1
assert folder.get_file_by_id(file.uuid) is file
assert len(file_system.get_folder("root").files) == 1
# deleting should not change how many files are in folder
file_system.delete_file(file=not_added_file)
assert len(file_system.folders) is 1
assert len(list(folder.files)) is 1
file_system.delete_file(folder_name="root", file_name="does_not_exist!")
# should still only be one folder
assert len(file_system.folders) == 1
# The folder should still have 1 file
assert len(file_system.get_folder("root").files) == 1
def test_delete_folder():
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
def test_delete_folder(file_system):
file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) == 2
file_system.delete_folder(folder)
assert len(file_system.folders) is 0
file_system.delete_folder(folder_name="test_folder")
assert len(file_system.folders) == 1
def test_deleting_a_non_existent_folder():
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
not_added_folder = FileSystemFolder(name="fake_folder")
assert len(file_system.folders) is 1
def test_deleting_a_non_existent_folder(file_system):
file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) == 2
file_system.delete_folder(not_added_folder)
assert len(file_system.folders) is 1
file_system.delete_folder(folder_name="does not exist!")
assert len(file_system.folders) == 2
def test_move_file():
def test_deleting_root_folder_fails(file_system):
assert len(file_system.folders) == 1
file_system.delete_folder(folder_name="root")
assert len(file_system.folders) == 1
def test_move_file(file_system):
"""Tests the file move function."""
file_system = FileSystem()
src_folder = file_system.create_folder(folder_name="test_folder_1")
assert len(file_system.folders) is 1
file_system.create_folder(folder_name="src_folder")
file_system.create_folder(folder_name="dst_folder")
target_folder = file_system.create_folder(folder_name="test_folder_2")
assert len(file_system.folders) is 2
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder")
original_uuid = file.uuid
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid)
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
file_system.move_file(file=file, src_folder=src_folder, target_folder=target_folder)
file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 0
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1
assert len(file_system.get_folder("src_folder").files) == 0
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid
def test_copy_file():
def test_copy_file(file_system):
"""Tests the file copy function."""
file_system = FileSystem()
src_folder = file_system.create_folder(folder_name="test_folder_1")
assert len(file_system.folders) is 1
file_system.create_folder(folder_name="src_folder")
file_system.create_folder(folder_name="dst_folder")
target_folder = file_system.create_folder(folder_name="test_folder_2")
assert len(file_system.folders) is 2
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
original_uuid = file.uuid
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid)
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
file_system.copy_file(file=file, src_folder=src_folder, target_folder=target_folder)
file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
def test_serialisation():
def test_folder_quarantine_state(file_system):
"""Tests the changing of folder quarantine status."""
folder = file_system.get_folder("root")
assert folder.quarantine_status() is False
folder.quarantine()
assert folder.quarantine_status() is True
folder.unquarantine()
assert folder.quarantine_status() is False
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly."""
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid)
assert file_system.get_folder_by_id(folder.uuid) is folder
file_system.create_file(file_name="test_file.txt")
serialised_file_sys = file_system.model_dump_json()
deserialised_file_sys = FileSystem.model_validate_json(serialised_file_sys)

View File

@@ -1,23 +0,0 @@
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
def test_file_type():
"""Tests tha the FileSystemFile type is set correctly."""
file = FileSystemFile(name="test", file_type=FileSystemFileType.DOC)
assert file.file_type is FileSystemFileType.DOC
def test_get_size():
"""Tests that the file size is being returned properly."""
file = FileSystemFile(name="test", size=1.5)
assert file.size == 1.5
def test_serialisation():
"""Test to check that the object serialisation works correctly."""
file = FileSystemFile(name="test", size=1.5, file_type=FileSystemFileType.DOC)
serialised_file = file.model_dump_json()
deserialised_file = FileSystemFile.model_validate_json(serialised_file)
assert file.model_dump_json() == deserialised_file.model_dump_json()

View File

@@ -1,75 +0,0 @@
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
def test_adding_removing_file():
"""Test the adding and removing of a file from a folder."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
assert folder.size == 10
assert len(folder.files) is 1
folder.remove_file(file)
assert folder.size == 0
assert len(folder.files) is 0
def test_remove_non_existent_file():
"""Test the removing of a file that does not exist."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
not_added_file = FileSystemFile(name="fake_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
assert folder.size == 10
assert len(folder.files) is 1
folder.remove_file(not_added_file)
assert folder.size == 10
assert len(folder.files) is 1
def test_get_file_by_id():
"""Test to make sure that the correct file is returned."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
file2 = FileSystemFile(name="test_file_2", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
folder.add_file(file2)
assert folder.size == 20
assert len(folder.files) is 2
assert folder.get_file_by_id(file_id=file.uuid) is file
def test_folder_quarantine_state():
"""Tests the changing of folder quarantine status."""
folder = FileSystemFolder(name="test")
assert folder.quarantine_status() is False
folder.quarantine()
assert folder.quarantine_status() is True
folder.end_quarantine()
assert folder.quarantine_status() is False
def test_serialisation():
"""Test to check that the object serialisation works correctly."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
serialised_folder = folder.model_dump_json()
deserialised_folder = FileSystemFolder.model_validate_json(serialised_folder)
assert folder.model_dump_json() == deserialised_folder.model_dump_json()

View File

@@ -1,5 +1,7 @@
import json
import pytest
from primaite.simulator.network.container import Network
@@ -10,6 +12,7 @@ def test_creating_container():
assert net.links == {}
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_describe_state():
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
net = Network()

View File

@@ -0,0 +1,20 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_creation():
network = arcd_uc2_network()
client_1: Node = network.get_node_by_hostname("client_1")
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;"

View File

@@ -1,32 +0,0 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.red_services.data_manipulator_service import DataManipulatorService
def test_creation():
network = arcd_uc2_network()
client_1: Node = network.get_node_by_hostname("client_1")
client_1.software_manager.add_service(service_class=DataManipulatorService)
data_manipulator_service: DataManipulatorService = client_1.software_manager.services["DataManipulatorBot"]
assert data_manipulator_service.name == "DataManipulatorBot"
assert data_manipulator_service.port == Port.POSTGRES_SERVER
assert data_manipulator_service.protocol == IPProtocol.TCP
# should have no session yet
assert len(client_1.session_manager.sessions_by_uuid) == 0
try:
data_manipulator_service.start(target_ip_address=IPv4Address("192.168.1.14"))
except Exception as e:
assert False, f"Test was not supposed to throw exception: {e}"
# there should be a session after the service is started
assert len(client_1.session_manager.sessions_by_uuid) == 1

View File

@@ -1,15 +1,18 @@
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
import json
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.system.services.database_service import DatabaseService
def test_creation():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
port=Port.SQL_SERVER,
operating_state=ServiceOperatingState.RUNNING,
)
@pytest.fixture(scope="function")
def database_server() -> Node:
node = Node(hostname="db_node")
node.software_manager.install(DatabaseService)
node.software_manager.software["DatabaseService"].start()
return node
def test_creation(database_server):
database_server.software_manager.show()