Merge remote-tracking branch 'origin/dev' into feature/1812-traverse-actions-dict

This commit is contained in:
Marek Wolan
2023-09-19 10:12:47 +01:00
54 changed files with 2552 additions and 1004 deletions

2
.gitignore vendored
View File

@@ -144,9 +144,11 @@ cython_debug/
# IDE
.idea/
docs/source/primaite-dependencies.rst
.vscode/
# outputs
src/primaite/outputs/
simulation_output/
# benchmark session outputs
benchmark/output

View File

@@ -27,6 +27,9 @@ SessionManager.
- File System - ability to emulate a node's file system during a simulation
- Example notebooks - There is currently 1 jupyter notebook which walks through using PrimAITE
1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP)
- Red Agent Services:
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database)
- DNS Services: DNS Client and DNS Server
## [2.0.0] - 2023-07-26

View File

@@ -98,6 +98,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/getting_started
source/about
source/config
source/simulation
source/primaite_session
source/custom_agent
PrimAITE API <source/_autosummary/primaite>

View File

@@ -21,3 +21,5 @@ Contents
simulation_components/network/router
simulation_components/network/switch
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/software

View File

@@ -2,7 +2,7 @@
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:
.. _network:
Network
=======

View File

@@ -2,7 +2,7 @@
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _about:
.. _router:
Router Module
=============

View File

@@ -0,0 +1,58 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
DataManipulationBot
===================
The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements.
Overview
--------
The bot is intended to simulate a malicious actor carrying out attacks like:
- Dropping tables
- Deleting records
- Modifying data
On a database server by abusing an application's trusted database connectivity.
Usage
-----
- Create an instance and call ``configure`` to set:
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Call ``run`` to connect and execute the statement.
The bot handles connecting, executing the statement, and disconnecting.
Example
-------
.. code-block:: python
client_1 = Computer(
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
data_manipulation_bot.run()
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.
Implementation
--------------
The bot extends ``DatabaseClient`` and leverages its connectivity.
- Uses the Application base class for lifecycle management.
- Credentials and target IP set via ``configure``.
- ``run`` handles connecting, executing statement, and disconnecting.
- SQL payload executed via ``query`` method.
- Results in malicious SQL being executed on remote database server.

View File

@@ -0,0 +1,70 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Database Client Server
======================
Database Service
----------------
The ``DatabaseService`` provides a SQL database server simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Initialises a SQLite database file in the ``Node``'s ``FileSystem`` upon creation.
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
- Authenticates connections using a configurable password.
- Executes SQL queries against the SQLite database.
- Returns query results and status codes back to clients.
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Clients connect, execute queries, and disconnect.
- Service runs on TCP port 5432 by default.
Implementation
^^^^^^^^^^^^^^
- Uses SQLite for persistent storage.
- Creates the database file within the node's file system.
- Manages client connections in a dictionary by session ID.
- Processes SQL queries via the SQLite cursor and connection.
- Returns results and status codes in a standard dictionary format.
- Extends Service class for integration with ``SoftwareManager``.
Database Client
---------------
The DatabaseClient provides a client interface for connecting to the ``DatabaseService``.
Key features
^^^^^^^^^^^^
- Connects to the ``DatabaseService`` via the ``SoftwareManager``.
- Executes SQL queries and retrieves result sets.
- Handles connecting, querying, and disconnecting.
- Provides a simple ``query`` method for running SQL.
Usage
^^^^^
- Initialise with server IP address and optional password.
- Connect to the ``DatabaseService`` with ``connect``.
- Execute SQL queries via ``query``.
- Retrieve results in a dictionary.
- Disconnect when finished.
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Connect and disconnect methods manage sessions.
- Provides easy interface for applications to query database.
- Payloads serialised as dictionaries for transmission.
- Extends base Application class.

View File

@@ -0,0 +1,56 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
DNS Client Server
=================
DNS Server
----------
Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class.
Key capabilities
^^^^^^^^^^^^^^^^
- Simulates DNS requests and DNSPacket transfer across a network
- Registers domain names and the IP Address linked to the domain name
- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read
- Leverages the Service base class for install/uninstall, status tracking, etc.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
Implementation
^^^^^^^^^^^^^^
- DNS request and responses use a ``DNSPacket`` object
- Extends Service class for integration with ``SoftwareManager``.
DNS Client
----------
The DNSClient provides a client interface for connecting to the ``DNSServer``.
Key features
^^^^^^^^^^^^
- Connects to the ``DNSServer`` via the ``SoftwareManager``.
- Executes DNS lookup requests and keeps a cache of known domain name IP addresses.
- Handles connection to DNSServer and querying for domain name IP addresses.
Usage
^^^^^
- Install on a Node via the ``SoftwareManager`` to start the database service.
- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future)
- Execute domain name checks with ``check_domain_exists``.
- ``DNSClient`` will automatically add the IP Address of the domain into its cache
Implementation
^^^^^^^^^^^^^^
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Provides easy interface for Nodes to find IP addresses via domain names.
- Extends base Service class.

View File

@@ -0,0 +1,98 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _internal_frame_processing:
Internal Frame Processing
=========================
Inbound
-------
At the NIC
^^^^^^^^^^
When a Frame is received on the Node's NIC:
- The NIC checks if it is enabled. If so, it will process the Frame.
- The Frame's received timestamp is set.
- The Frame is captured by the NIC's PacketCapture if configured.
- The NIC decrements the IP Packet's TTL by 1.
- The NIC calls the Node's ``receive_frame`` method, passing itself as the receiving NIC and the Frame.
At the Node
^^^^^^^^^^^
When ``receive_frame`` is called on the Node:
- The source IP address is added to the ARP cache if not already present.
- The Frame's protocol is checked:
- If ARP or ICMP, the Frame is passed to that protocol's handler method.
- Otherwise it is passed to the SessionManager's ``receive_frame`` method.
At the SessionManager
^^^^^^^^^^^^^^^^^^^^^
When ``receive_frame`` is called on the SessionManager:
- It extracts the key session details from the Frame:
- Protocol (TCP, UDP, etc)
- Source IP
- Destination IP
- Source Port
- Destination Port
- It checks if an existing Session matches these details.
- If no match, a new Session is created to represent this exchange.
- The payload and new/existing Session ID are passed to the SoftwareManager's ``receive_payload_from_session_manager`` method.
At the SoftwareManager
^^^^^^^^^^^^^^^^^^^^^^
Inside ``receive_payload_from_session_manager``:
- The SoftwareManager checks its port/protocol mapping to find which Service or Application is listening on the destination port and protocol.
- The payload and Session ID are forwarded to that receiver Service/Application instance via their ``receive`` method.
- The Service/Application can then process the payload as needed.
Outbound
--------
At the Service/Application
^^^^^^^^^^^^^^^^^^^^^^^^^^
When a Service or Application needs to send a payload:
- It calls the SoftwareManager's ``send_payload_to_session_manager`` method.
- Passes the payload, and either destination IP and destination port for new payloads, or session id for existing sessions.
At the SoftwareManager
^^^^^^^^^^^^^^^^^^^^^^
Inside ``send_payload_to_session_manager``:
- The SoftwareManager forwards the payload and details through to to the SessionManager's ``receive_payload_from_software_manager`` method.
At the SessionManager
^^^^^^^^^^^^^^^^^^^^^
When ``receive_payload_from_software_manager`` is called:
- If a Session ID was provided, it looks up the Session.
- Gets the destination MAC address by checking the ARP cache.
- If no Session ID was provided, the destination Port, IP address and Mac Address are used along with the outbound IP Address and Mac Address to create a new Session.
- Calls `send_payload_to_nic`` to construct and send the Frame.
When ``send_payload_to_nic`` is called:
- It constructs a new Frame with the payload, using the source NIC's MAC, source IP, destination MAC, etc.
- The outbound NIC is looked up via the ARP cache based on destination IP.
- The constructed Frame is passed to the outbound NIC's ``send_frame`` method.
At the NIC
^^^^^^^^^^
When ``send_frame`` is called:
- The NIC checks if it is enabled before sending.
- If enabled, it sends the Frame out to the connected Link.

View File

@@ -0,0 +1,20 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Software
========
Contents
########
.. toctree::
:maxdepth: 8
database_client_server
data_manipulation_bot
dns_client_server

View File

@@ -1,5 +1,14 @@
from datetime import datetime
from primaite import _PRIMAITE_ROOT
TEMP_SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output"
SIM_OUTPUT = None
"A path at the repo root dir to use temporarily for sim output testing while in dev."
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
if not SIM_OUTPUT:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)

View File

@@ -149,8 +149,8 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.file_system.file_system_file_type import FileSystemFileType\n",
"from primaite.simulator.file_system.file_system_file import FileSystemFile"
"from primaite.simulator.file_system.file_type import FileType\n",
"from primaite.simulator.file_system.file_system import File"
]
},
{
@@ -160,7 +160,7 @@
"outputs": [],
"source": [
"my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n",
"my_pc_downloads_folder.add_file(FileSystemFile(name=\"firefox_installer.zip\",file_type=FileSystemFileType.ZIP))"
"my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))"
]
},
{
@@ -171,7 +171,7 @@
{
"data": {
"text/plain": [
"FileSystemFile(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=<FileSystemFileType.PNG: '11'>, action_manager=None)"
"File(uuid='7d56a563-ecc0-4011-8c97-240dd6c885c0', name='favicon.ico', size=40.0, file_type=<FileType.PNG: '11'>, action_manager=None)"
]
},
"execution_count": 9,
@@ -181,7 +181,7 @@
],
"source": [
"my_server_folder = my_server.file_system.create_folder(\"static\")\n",
"my_server.file_system.create_file(\"favicon.ico\", file_type=FileSystemFileType.PNG)"
"my_server.file_system.create_file(\"favicon.ico\", file_type=FileType.PNG)"
]
},
{

View File

@@ -222,6 +222,9 @@ class SimComponent(BaseModel):
:param action: List describing the action to apply to this object.
:type action: List[str]
:param: context: Dict containing context for actions
:type context: Dict
"""
if self.action_manager is None:
return

View File

@@ -1,242 +1,519 @@
from random import choice
from __future__ import annotations
import math
import os.path
import shutil
from pathlib import Path
from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import SimComponent
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
from primaite.simulator.system.core.sys_log import SysLog
_LOGGER = getLogger(__name__)
class FileSystem(SimComponent):
"""Class that contains all the simulation File System."""
def convert_size(size_bytes: int) -> str:
"""
Convert a file size from bytes to a string with a more human-readable format.
folders: Dict[str, FileSystemFolder] = {}
"""List containing all the folders in the file system."""
This function takes the size of a file in bytes and converts it to a string representation with appropriate size
units (B, KB, MB, GB, etc.).
:param size_bytes: The size of the file in bytes.
:return: The human-readable string representation of the file size.
"""
if size_bytes == 0:
return "0 B"
# Tuple of size units starting from Bytes up to Yottabytes
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
# Calculate the index (i) that will be used to select the appropriate size unit from size_name
i = int(math.floor(math.log(size_bytes, 1024)))
# Calculate the adjusted size value (s) in terms of the new size unit
p = math.pow(1024, i)
s = round(size_bytes / p, 2)
return f"{s} {size_name[i]}"
class FileSystemItemABC(SimComponent):
"""
Abstract base class for file system items used in the file system simulation.
:ivar name: The name of the FileSystemItemABC.
"""
name: str
"The name of the FileSystemItemABC."
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update({"folders": {uuid: folder.describe_state() for uuid, folder in self.folders.items()}})
state.update(
{
"name": self.name,
}
)
return state
def get_folders(self) -> Dict:
"""Returns the list of folders."""
return self.folders
@property
def size_str(self) -> str:
"""
Get the file size in a human-readable string format.
This property makes use of the :func:`convert_size` function to convert the `self.size` attribute to a string
that is easier to read and understand.
:return: The human-readable string representation of the file size.
"""
return convert_size(self.size)
class FileSystem(SimComponent):
"""Class that contains all the simulation File System."""
folders: Dict[str, Folder] = {}
"List containing all the folders in the file system."
_folders_by_name: Dict[str, Folder] = {}
sys_log: SysLog
sim_root: Path
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Ensure a default root folder
if not self.folders:
self.create_folder("root")
@property
def size(self) -> int:
"""
Calculate and return the total size of all folders in the file system.
:return: The sum of the sizes of all folders in the file system.
"""
return sum(folder.size for folder in self.folders.values())
def show(self, markdown: bool = False, full: bool = False):
"""
Prints a table of the FileSystem, displaying either just folders or full files.
:param markdown: Flag indicating if output should be in markdown format.
:param full: Flag indicating if to show full files.
"""
headers = ["Folder", "Size"]
if full:
headers[0] = "File Path"
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} File System"
for folder in self.folders.values():
if not full:
table.add_row([folder.name, folder.size_str])
else:
for file in folder.files.values():
table.add_row([file.path, file.size_str])
if full:
print(table.get_string(sortby="File Path"))
else:
print(table.get_string(sortby="Folder"))
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()}
return state
def create_folder(self, folder_name: str) -> Folder:
"""
Creates a Folder and adds it to the list of folders.
:param folder_name: The name of the folder.
"""
# check if folder with name already exists
if self.get_folder(folder_name):
raise Exception(f"Cannot create folder as it already exists: {folder_name}")
folder = Folder(name=folder_name, fs=self)
self.folders[folder.uuid] = folder
self._folders_by_name[folder.name] = folder
self.sys_log.info(f"Created folder /{folder.name}")
return folder
def delete_folder(self, folder_name: str):
"""
Deletes a folder, removes it from the folders list and removes any child folders and files.
:param folder_name: The name of the folder.
"""
if folder_name == "root":
self.sys_log.warning("Cannot delete the root folder.")
return
folder = self._folders_by_name.get(folder_name)
if folder:
for file in folder.files.values():
self.delete_file(file)
self.folders.pop(folder.uuid)
self._folders_by_name.pop(folder.name)
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
else:
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
def create_file(
self,
file_name: str,
size: Optional[float] = None,
file_type: Optional[FileSystemFileType] = None,
folder: Optional[FileSystemFolder] = None,
folder_uuid: Optional[str] = None,
) -> FileSystemFile:
size: Optional[int] = None,
file_type: Optional[FileType] = None,
folder_name: Optional[str] = None,
real: bool = False,
) -> File:
"""
Creates a FileSystemFile and adds it to the list of files.
Creates a File and adds it to the list of files.
If no size or file_type are provided, one will be chosen randomly.
If no folder_uuid or folder is provided, a new folder will be created.
:param: file_name: The file name
:type: file_name: str
:param: size: The size the file takes on disk.
:type: size: Optional[float]
:param: file_type: The type of the file
:type: Optional[FileSystemFileType]
:param: folder: The folder to add the file to
:type: folder: Optional[FileSystemFolder]
:param: folder_uuid: The uuid of the folder to add the file to
:type: folder_uuid: Optional[str]
:param file_name: The file name.
:param size: The size the file takes on disk in bytes.
:param file_type: The type of the file.
:param folder_name: The folder to add the file to.
:param real: "Indicates whether the File is actually a real file in the Node sim fs output."
"""
file = None
folder = None
if file_type is None:
file_type = self.get_random_file_type()
# if no folder uuid provided, create a folder and add file to it
if folder_uuid is not None:
# otherwise check for existence and add file
folder = self.get_folder_by_id(folder_uuid)
if folder is not None:
if folder_name:
# check if file with name already exists
if folder.get_file_by_name(file_name):
raise Exception(f'File with name "{file_name}" already exists.')
file = FileSystemFile(name=file_name, size=size, file_type=file_type)
folder.add_file(file=file)
folder = self._folders_by_name.get(folder_name)
# If not then create it
if not folder:
folder = self.create_folder(folder_name)
else:
# check if a "root" folder exists
folder = self.get_folder_by_name("root")
if folder is None:
# create a root folder
folder = FileSystemFolder(name="root")
# Use root folder if folder_name not supplied
folder = self._folders_by_name["root"]
# add file to root folder
file = FileSystemFile(name=file_name, size=size, file_type=file_type)
folder.add_file(file)
self.folders[folder.uuid] = folder
# Create the file and add it to the folder
file = File(
name=file_name,
sim_size=size,
file_type=file_type,
folder=folder,
real=real,
sim_path=self.sim_root if real else None,
)
folder.add_file(file)
self.sys_log.info(f"Created file /{file.path}")
return file
def create_folder(
self,
folder_name: str,
) -> FileSystemFolder:
def get_file(self, folder_name: str, file_name: str) -> Optional[File]:
"""
Creates a FileSystemFolder and adds it to the list of folders.
Retrieve a file by its name from a specific folder.
:param: folder_name: The name of the folder
:type: folder_name: str
:param folder_name: The name of the folder where the file resides.
:param file_name: The name of the file to be retrieved, including its extension.
:return: An instance of File if it exists, otherwise `None`.
"""
# check if folder with name already exists
if self.get_folder_by_name(folder_name):
raise Exception(f'Folder with name "{folder_name}" already exists.')
folder = self.get_folder(folder_name)
if folder:
return folder.get_file(file_name)
self.fs.sys_log.info(f"file not found /{folder_name}/{file_name}")
folder = FileSystemFolder(name=folder_name)
self.folders[folder.uuid] = folder
return folder
def delete_file(self, file: Optional[FileSystemFile] = None):
def delete_file(self, folder_name: str, file_name: str):
"""
Deletes a file and removes it from the files list.
Delete a file by its name from a specific folder.
:param file: The file to delete
:type file: Optional[FileSystemFile]
:param folder_name: The name of the folder containing the file.
:param file_name: The name of the file to be deleted, including its extension.
"""
# iterate through folders to delete the item with the matching uuid
for key in self.folders:
self.get_folder_by_id(key).remove_file(file)
folder = self.get_folder(folder_name)
if folder:
file = folder.get_file(file_name)
if file:
folder.remove_file(file)
self.sys_log.info(f"Deleted file /{file.path}")
def delete_folder(self, folder: FileSystemFolder):
def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Deletes a folder, removes it from the folders list and removes any child folders and files.
Move a file from one folder to another.
:param folder: The folder to remove
:type folder: FileSystemFolder
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be moved.
:param dst_folder_name: The name of the destination folder.
"""
if folder is None or not isinstance(folder, FileSystemFolder):
raise Exception(f"Invalid folder: {folder}")
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
src_folder = file.folder
if self.folders.get(folder.uuid):
del self.folders[folder.uuid]
# remove file from src
src_folder.remove_file(file)
dst_folder = self.get_folder(folder_name=dst_folder_name)
if not dst_folder:
dst_folder = self.create_folder(dst_folder_name)
# add file to dst
dst_folder.add_file(file)
if file.real:
old_sim_path = file.sim_path
file.sim_path = file.folder.fs.sim_root / file.path
file.sim_path.parent.mkdir(exist_ok=True)
shutil.move(old_sim_path, file.sim_path)
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
"""
Copy a file from one folder to another.
:param src_folder_name: The name of the source folder containing the file.
:param src_file_name: The name of the file to be copied.
:param dst_folder_name: The name of the destination folder.
"""
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
dst_folder = self.get_folder(folder_name=dst_folder_name)
if not dst_folder:
dst_folder = self.create_folder(dst_folder_name)
new_file = file.make_copy(dst_folder=dst_folder)
dst_folder.add_file(new_file)
if file.real:
new_file.sim_path.parent.mkdir(exist_ok=True)
shutil.copy2(file.sim_path, new_file.sim_path)
def get_folder(self, folder_name: str) -> Optional[Folder]:
"""
Get a folder by its name if it exists.
:param folder_name: The folder name.
:return: The matching Folder.
"""
return self._folders_by_name.get(folder_name)
def get_folder_by_id(self, folder_uuid: str) -> Optional[Folder]:
"""
Get a folder by its uuid if it exists.
:param folder_uuid: The folder uuid.
:return: The matching Folder.
"""
return self.folders.get(folder_uuid)
class Folder(FileSystemItemABC):
"""Simulation Folder."""
fs: FileSystem
"The FileSystem the Folder is in."
files: Dict[str, File] = {}
"Files stored in the folder."
_files_by_name: Dict[str, File] = {}
"Files by their name as <file name>.<file type>."
is_quarantined: bool = False
"Flag that marks the folder as quarantined if true."
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
:return: Current state of this object and child objects.
"""
state = super().describe_state()
state["files"] = {file.name: file.describe_state() for uuid, file in self.files.items()}
state["is_quarantined"] = self.is_quarantined
return state
def show(self, markdown: bool = False):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["File", "Size"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.fs.sys_log.hostname} File System Folder ({self.name})"
for file in self.files.values():
table.add_row([file.name, file.size_str])
print(table.get_string(sortby="File"))
@property
def size(self) -> int:
"""
Calculate and return the total size of all files in the folder.
:return: The total size of all files in the folder. If no files exist or all have `None`
size, returns 0.
"""
return sum(file.size for file in self.files.values() if file.size is not None)
def get_file(self, file_name: str) -> Optional[File]:
"""
Get a file by its name.
File name must be the filename and prefix, like 'memo.docx'.
:param file_name: The file name.
:return: The matching File.
"""
# TODO: Increment read count?
return self._files_by_name.get(file_name)
def get_file_by_id(self, file_uuid: str) -> File:
"""
Get a file by its uuid.
:param file_uuid: The file uuid.
:return: The matching File.
"""
return self.files.get(file_uuid)
def add_file(self, file: File):
"""
Adds a file to the folder.
:param File file: The File object to be added to the folder.
:raises Exception: If the provided `file` parameter is None or not an instance of the
`File` class.
"""
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
# check if file with id already exists in folder
if file.uuid in self.files:
_LOGGER.debug(f"File with id {file.uuid} already exists in folder")
else:
_LOGGER.debug(f"File with UUID {folder.uuid} was not found.")
# add to list
self.files[file.uuid] = file
self._files_by_name[file.name] = file
file.folder = self
def move_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder):
def remove_file(self, file: Optional[File]):
"""
Moves a file from one folder to another.
Removes a file from the folder list.
can provide
The method can take a File object or a file id.
:param: file: The file to move
:type: file: FileSystemFile
:param: src_folder: The folder where the file is located
:type: FileSystemFolder
:param: target_folder: The folder where the file should be moved to
:type: FileSystemFolder
:param file: The file to remove
"""
# check that the folders exist
if src_folder is None:
raise Exception("Source folder not provided")
if file is None or not isinstance(file, File):
raise Exception(f"Invalid file: {file}")
if target_folder is None:
raise Exception("Target folder not provided")
if self.files.get(file.uuid):
self.files.pop(file.uuid)
self._files_by_name.pop(file.name)
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
if file is None:
raise Exception("File to be moved is None")
def quarantine(self):
"""Quarantines the File System Folder."""
if not self.is_quarantined:
self.is_quarantined = True
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
# check if file with name already exists
if target_folder.get_file_by_name(file.name):
raise Exception(f'Folder with name "{file.name}" already exists.')
def unquarantine(self):
"""Unquarantine of the File System Folder."""
if self.is_quarantined:
self.is_quarantined = False
self.fs.sys_log.info(f"Quarantined folder ./{self.name}")
# remove file from src
src_folder.remove_file(file)
def quarantine_status(self) -> bool:
"""Returns true if the folder is being quarantined."""
return self.is_quarantined
# add file to target
target_folder.add_file(file)
def copy_file(self, file: FileSystemFile, src_folder: FileSystemFolder, target_folder: FileSystemFolder):
class File(FileSystemItemABC):
"""
Class representing a file in the simulation.
:ivar Folder folder: The folder in which the file resides.
:ivar FileType file_type: The type of the file.
:ivar Optional[int] sim_size: The simulated file size.
:ivar bool real: Indicates if the file is actually a real file in the Node sim fs output.
:ivar Optional[Path] sim_path: The path if the file is real.
"""
folder: Folder
"The Folder the File is in."
file_type: FileType
"The type of File."
sim_size: Optional[int] = None
"The simulated file size."
real: bool = False
"Indicates whether the File is actually a real file in the Node sim fs output."
sim_path: Optional[Path] = None
"The Path if real is True."
def __init__(self, **kwargs):
"""
Copies a file from one folder to another.
Initialise File class.
can provide
:param: file: The file to move
:type: file: FileSystemFile
:param: src_folder: The folder where the file is located
:type: FileSystemFolder
:param: target_folder: The folder where the file should be moved to
:type: FileSystemFolder
:param name: The name of the file.
:param file_type: The FileType of the file
:param size: The size of the FileSystemItemABC
"""
if src_folder is None:
raise Exception("Source folder not provided")
has_extension = "." in kwargs["name"]
if target_folder is None:
raise Exception("Target folder not provided")
# Attempt to use the file type extension to set/override the FileType
if has_extension:
extension = kwargs["name"].split(".")[-1]
kwargs["file_type"] = get_file_type_from_extension(extension)
else:
# If the file name does not have a extension, override file type to FileType.UNKNOWN
if not kwargs["file_type"]:
kwargs["file_type"] = FileType.UNKNOWN
if kwargs["file_type"] != FileType.UNKNOWN:
kwargs["name"] = f"{kwargs['name']}.{kwargs['file_type'].name.lower()}"
if file is None:
raise Exception("File to be moved is None")
# set random file size if none provided
if not kwargs.get("sim_size"):
kwargs["sim_size"] = kwargs["file_type"].default_size
super().__init__(**kwargs)
if self.real:
self.sim_path = self.folder.fs.sim_root / self.path
if not self.sim_path.exists():
self.sim_path.parent.mkdir(exist_ok=True, parents=True)
with open(self.sim_path, mode="a"):
pass
# check if file with name already exists
if target_folder.get_file_by_name(file.name):
raise Exception(f'Folder with name "{file.name}" already exists.')
# add file to target
target_folder.add_file(file)
def get_file_by_id(self, file_id: str) -> FileSystemFile:
"""Checks if the file exists in any file system folders."""
for key in self.folders:
file = self.folders[key].get_file_by_id(file_id=file_id)
if file is not None:
return file
def get_folder_by_name(self, folder_name: str) -> Optional[FileSystemFolder]:
def make_copy(self, dst_folder: Folder) -> File:
"""
Returns a the first folder with a matching name.
Create a copy of the current File object in the given destination folder.
:return: Returns the first FileSydtemFolder with a matching name
:param Folder dst_folder: The destination folder for the copied file.
:return: A new File object that is a copy of the current file.
"""
matching_folder = None
for key in self.folders:
if self.folders[key].name == folder_name:
matching_folder = self.folders[key]
break
return matching_folder
return File(folder=dst_folder, **self.model_dump(exclude={"uuid", "folder", "sim_path"}))
def get_folder_by_id(self, folder_id: str) -> FileSystemFolder:
@property
def path(self) -> str:
"""
Checks if the folder exists.
Get the path of the file in the file system.
:param: folder_id: The id of the folder to find
:type: folder_id: str
:return: The full path of the file.
"""
return self.folders[folder_id]
return f"{self.folder.name}/{self.name}"
def get_random_file_type(self) -> FileSystemFileType:
@property
def size(self) -> int:
"""
Returns a random FileSystemFileTypeEnum.
Get the size of the file in bytes.
:return: A random file type Enum
:return: The size of the file in bytes.
"""
return choice(list(FileSystemFileType))
if self.real:
return os.path.getsize(self.sim_path)
return self.sim_size
def describe_state(self) -> Dict:
"""Produce a dictionary describing the current state of this object."""
state = super().describe_state()
state["size"] = self.size
state["file_type"] = self.file_type.name
return state

View File

@@ -1,55 +0,0 @@
from random import choice
from typing import Dict
from primaite.simulator.file_system.file_system_file_type import file_type_sizes_KB, FileSystemFileType
from primaite.simulator.file_system.file_system_item_abc import FileSystemItem
class FileSystemFile(FileSystemItem):
"""Class that represents a file in the simulation."""
file_type: FileSystemFileType = None
"""The type of the FileSystemFile"""
def __init__(self, **kwargs):
"""
Initialise FileSystemFile class.
:param name: The name of the file.
:type name: str
:param file_type: The FileSystemFileType of the file
:type file_type: Optional[FileSystemFileType]
:param size: The size of the FileSystemItem
:type size: Optional[float]
"""
# set random file type if none provided
# set random file type if none provided
if kwargs.get("file_type") is None:
kwargs["file_type"] = choice(list(FileSystemFileType))
# set random file size if none provided
if kwargs.get("size") is None:
kwargs["size"] = file_type_sizes_KB[kwargs["file_type"]]
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"uuid": self.uuid,
"file_type": self.file_type.name,
}
)
return state

View File

@@ -1,132 +0,0 @@
from enum import Enum
class FileSystemFileType(str, Enum):
"""An enumeration of common file types."""
UNKNOWN = 0
"Unknown file type."
# Text formats
TXT = 1
"Plain text file."
DOC = 2
"Microsoft Word document (.doc)"
DOCX = 3
"Microsoft Word document (.docx)"
PDF = 4
"Portable Document Format."
HTML = 5
"HyperText Markup Language file."
XML = 6
"Extensible Markup Language file."
CSV = 7
"Comma-Separated Values file."
# Spreadsheet formats
XLS = 8
"Microsoft Excel file (.xls)"
XLSX = 9
"Microsoft Excel file (.xlsx)"
# Image formats
JPEG = 10
"JPEG image file."
PNG = 11
"PNG image file."
GIF = 12
"GIF image file."
BMP = 13
"Bitmap image file."
# Audio formats
MP3 = 14
"MP3 audio file."
WAV = 15
"WAV audio file."
# Video formats
MP4 = 16
"MP4 video file."
AVI = 17
"AVI video file."
MKV = 18
"MKV video file."
FLV = 19
"FLV video file."
# Presentation formats
PPT = 20
"Microsoft PowerPoint file (.ppt)"
PPTX = 21
"Microsoft PowerPoint file (.pptx)"
# Web formats
JS = 22
"JavaScript file."
CSS = 23
"Cascading Style Sheets file."
# Programming languages
PY = 24
"Python script file."
C = 25
"C source code file."
CPP = 26
"C++ source code file."
JAVA = 27
"Java source code file."
# Compressed file types
RAR = 28
"RAR archive file."
ZIP = 29
"ZIP archive file."
TAR = 30
"TAR archive file."
GZ = 31
"Gzip compressed file."
# Database file types
MDF = 32
"MS SQL Server primary database file"
NDF = 33
"MS SQL Server secondary database file"
LDF = 34
"MS SQL Server transaction log"
file_type_sizes_KB = {
FileSystemFileType.UNKNOWN: 0,
FileSystemFileType.TXT: 4,
FileSystemFileType.DOC: 50,
FileSystemFileType.DOCX: 30,
FileSystemFileType.PDF: 100,
FileSystemFileType.HTML: 15,
FileSystemFileType.XML: 10,
FileSystemFileType.CSV: 15,
FileSystemFileType.XLS: 100,
FileSystemFileType.XLSX: 25,
FileSystemFileType.JPEG: 100,
FileSystemFileType.PNG: 40,
FileSystemFileType.GIF: 30,
FileSystemFileType.BMP: 300,
FileSystemFileType.MP3: 5000,
FileSystemFileType.WAV: 25000,
FileSystemFileType.MP4: 25000,
FileSystemFileType.AVI: 50000,
FileSystemFileType.MKV: 50000,
FileSystemFileType.FLV: 15000,
FileSystemFileType.PPT: 200,
FileSystemFileType.PPTX: 100,
FileSystemFileType.JS: 10,
FileSystemFileType.CSS: 5,
FileSystemFileType.PY: 5,
FileSystemFileType.C: 5,
FileSystemFileType.CPP: 10,
FileSystemFileType.JAVA: 10,
FileSystemFileType.RAR: 1000,
FileSystemFileType.ZIP: 1000,
FileSystemFileType.TAR: 1000,
FileSystemFileType.GZ: 800,
}

View File

@@ -1,87 +0,0 @@
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_item_abc import FileSystemItem
_LOGGER = getLogger(__name__)
class FileSystemFolder(FileSystemItem):
"""Simulation FileSystemFolder."""
files: Dict[str, FileSystemFile] = {}
"""List of files stored in the folder."""
is_quarantined: bool = False
"""Flag that marks the folder as quarantined if true."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"files": {uuid: file.describe_state() for uuid, file in self.files.items()},
"is_quarantined": self.is_quarantined,
}
)
return state
def get_file_by_id(self, file_id: str) -> FileSystemFile:
"""Return a FileSystemFile with the matching id."""
return self.files.get(file_id)
def get_file_by_name(self, file_name: str) -> FileSystemFile:
"""Return a FileSystemFile with the matching id."""
return next((f for f in list(self.files) if f.name == file_name), None)
def add_file(self, file: FileSystemFile):
"""Adds a file to the folder list."""
if file is None or not isinstance(file, FileSystemFile):
raise Exception(f"Invalid file: {file}")
# check if file with id already exists in folder
if file.uuid in self.files:
_LOGGER.debug(f"File with id {file.uuid} already exists in folder")
else:
# add to list
self.files[file.uuid] = file
self.size += file.size
def remove_file(self, file: Optional[FileSystemFile]):
"""
Removes a file from the folder list.
The method can take a FileSystemFile object or a file id.
:param: file: The file to remove
:type: Optional[FileSystemFile]
"""
if file is None or not isinstance(file, FileSystemFile):
raise Exception(f"Invalid file: {file}")
if self.files.get(file.uuid):
del self.files[file.uuid]
self.size -= file.size
else:
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
def quarantine(self):
"""Quarantines the File System Folder."""
self.is_quarantined = True
def end_quarantine(self):
"""Ends the quarantine of the File System Folder."""
self.is_quarantined = False
def quarantine_status(self) -> bool:
"""Returns true if the folder is being quarantined."""
return self.is_quarantined

View File

@@ -1,31 +0,0 @@
from typing import Dict
from primaite.simulator.core import SimComponent
class FileSystemItem(SimComponent):
"""Abstract base class for FileSystemItems used in the file system simulation."""
name: str
"""The name of the FileSystemItem."""
size: float = 0
"""The size the item takes up on disk."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
state.update(
{
"name": self.name,
"size": self.size,
}
)
return state

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
from enum import Enum
from random import choice
from typing import Any
class FileType(Enum):
"""An enumeration of common file types."""
UNKNOWN = 0
"Unknown file type."
# Text formats
TXT = 1
"Plain text file."
DOC = 2
"Microsoft Word document (.doc)"
DOCX = 3
"Microsoft Word document (.docx)"
PDF = 4
"Portable Document Format."
HTML = 5
"HyperText Markup Language file."
XML = 6
"Extensible Markup Language file."
CSV = 7
"Comma-Separated Values file."
# Spreadsheet formats
XLS = 8
"Microsoft Excel file (.xls)"
XLSX = 9
"Microsoft Excel file (.xlsx)"
# Image formats
JPEG = 10
"JPEG image file."
PNG = 11
"PNG image file."
GIF = 12
"GIF image file."
BMP = 13
"Bitmap image file."
# Audio formats
MP3 = 14
"MP3 audio file."
WAV = 15
"WAV audio file."
# Video formats
MP4 = 16
"MP4 video file."
AVI = 17
"AVI video file."
MKV = 18
"MKV video file."
FLV = 19
"FLV video file."
# Presentation formats
PPT = 20
"Microsoft PowerPoint file (.ppt)"
PPTX = 21
"Microsoft PowerPoint file (.pptx)"
# Web formats
JS = 22
"JavaScript file."
CSS = 23
"Cascading Style Sheets file."
# Programming languages
PY = 24
"Python script file."
C = 25
"C source code file."
CPP = 26
"C++ source code file."
JAVA = 27
"Java source code file."
# Compressed file types
RAR = 28
"RAR archive file."
ZIP = 29
"ZIP archive file."
TAR = 30
"TAR archive file."
GZ = 31
"Gzip compressed file."
# Database file types
DB = 32
"Generic DB file. Used by sqlite3."
@classmethod
def _missing_(cls, value: Any) -> FileType:
return cls.UNKNOWN
@classmethod
def random(cls) -> FileType:
"""
Returns a random FileType.
:return: A random FileType.
"""
return choice(list(FileType))
@property
def default_size(self) -> int:
"""
Get the default size of the FileType in bytes.
Returns 0 if a default size does not exist.
"""
size = file_type_sizes_bytes[self]
return size if size else 0
def get_file_type_from_extension(file_type_extension: str) -> FileType:
"""
Get a FileType from a file type extension.
If a matching extension does not exist, FileType.UNKNOWN is returned.
:param file_type_extension: A file type extension.
:return: A file type extension.
"""
try:
return FileType[file_type_extension.upper()]
except KeyError:
return FileType.UNKNOWN
file_type_sizes_bytes = {
FileType.UNKNOWN: 0,
FileType.TXT: 4096,
FileType.DOC: 51200,
FileType.DOCX: 30720,
FileType.PDF: 102400,
FileType.HTML: 15360,
FileType.XML: 10240,
FileType.CSV: 15360,
FileType.XLS: 102400,
FileType.XLSX: 25600,
FileType.JPEG: 102400,
FileType.PNG: 40960,
FileType.GIF: 30720,
FileType.BMP: 307200,
FileType.MP3: 5120000,
FileType.WAV: 25600000,
FileType.MP4: 25600000,
FileType.AVI: 51200000,
FileType.MKV: 51200000,
FileType.FLV: 15360000,
FileType.PPT: 204800,
FileType.PPTX: 102400,
FileType.JS: 10240,
FileType.CSS: 5120,
FileType.PY: 5120,
FileType.C: 5120,
FileType.CPP: 10240,
FileType.JAVA: 10240,
FileType.RAR: 1024000,
FileType.ZIP: 1024000,
FileType.TAR: 1024000,
FileType.GZ: 819200,
FileType.DB: 15360000,
}

View File

@@ -6,7 +6,7 @@ from networkx import MultiGraph
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.core import Action, ActionManager, AllowAllValidator, SimComponent
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
@@ -29,6 +29,8 @@ class Network(SimComponent):
nodes: Dict[str, Node] = {}
links: Dict[str, Link] = {}
_node_id_map: Dict[int, Node] = {}
_link_id_map: Dict[int, Node] = {}
def __init__(self, **kwargs):
"""
@@ -47,7 +49,7 @@ class Network(SimComponent):
am.add_action(
"node",
Action(
func = self._node_action_manager
func=self._node_action_manager
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
),
)
@@ -161,8 +163,8 @@ class Network(SimComponent):
state = super().describe_state()
state.update(
{
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
"nodes": {i for i, node in self._node_id_map.items()},
"links": {i: link.describe_state() for i, link in self._link_id_map.items()},
}
)
return state
@@ -179,10 +181,11 @@ class Network(SimComponent):
_LOGGER.warning(f"Can't add node {node.uuid}. It is already in the network.")
return
self.nodes[node.uuid] = node
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_action_manager.add_action(name = node.uuid, action = Action(func=node._action_manager))
self._node_action_manager.add_action(name=node.uuid, action=Action(func=node._action_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -210,9 +213,13 @@ class Network(SimComponent):
_LOGGER.warning(f"Can't remove node {node.uuid}. It's not in the network.")
return
self.nodes.pop(node.uuid)
for i, _node in self._node_id_map.items():
if node == _node:
self._node_id_map.pop(i)
break
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_action_manager.remove_action(name = node.uuid)
self._node_action_manager.remove_action(name=node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""
@@ -237,9 +244,10 @@ class Network(SimComponent):
return
link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, **kwargs)
self.links[link.uuid] = link
self._link_id_map[len(self.links)] = link
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
link.parent = self
_LOGGER.info(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
def remove_link(self, link: Link) -> None:
"""Disconnect a link from the network.
@@ -250,6 +258,10 @@ class Network(SimComponent):
link.endpoint_a.disconnect_link()
link.endpoint_b.disconnect_link()
self.links.pop(link.uuid)
for i, _link in self._link_id_map.items():
if link == _link:
self._link_id_map.pop(i)
break
link.parent = None
_LOGGER.info(f"Removed link {link.uuid} from network {self.uuid}.")

View File

@@ -4,12 +4,14 @@ import re
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
@@ -87,8 +89,6 @@ class NIC(SimComponent):
"The Maximum Transmission Unit (MTU) of the NIC in Bytes. Default is 1500 B"
wake_on_lan: bool = False
"Indicates if the NIC supports Wake-on-LAN functionality."
dns_servers: List[IPv4Address] = []
"List of IP addresses of DNS servers used for name resolution."
connected_node: Optional[Node] = None
"The Node to which the NIC is connected."
connected_link: Optional[Link] = None
@@ -191,7 +191,7 @@ class NIC(SimComponent):
if self.connected_node:
self.connected_node.sys_log.info(f"NIC {self} disabled")
else:
_LOGGER.info(f"NIC {self} disabled")
_LOGGER.debug(f"NIC {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
@@ -213,7 +213,7 @@ class NIC(SimComponent):
# TODO: Inform the Node that a link has been connected
self.connected_link = link
self.enable()
_LOGGER.info(f"NIC {self} connected to Link {link}")
_LOGGER.debug(f"NIC {self} connected to Link {link}")
def disconnect_link(self):
"""Disconnect the NIC from the connected Link."""
@@ -356,7 +356,7 @@ class SwitchPort(SimComponent):
if self.connected_node:
self.connected_node.sys_log.info(f"SwitchPort {self} disabled")
else:
_LOGGER.info(f"SwitchPort {self} disabled")
_LOGGER.debug(f"SwitchPort {self} disabled")
if self.connected_link:
self.connected_link.endpoint_down()
@@ -376,7 +376,7 @@ class SwitchPort(SimComponent):
# TODO: Inform the Switch that a link has been connected
self.connected_link = link
_LOGGER.info(f"SwitchPort {self} connected to Link {link}")
_LOGGER.debug(f"SwitchPort {self} connected to Link {link}")
self.enable()
def disconnect_link(self):
@@ -411,7 +411,8 @@ class SwitchPort(SimComponent):
if self.enabled:
frame.decrement_ttl()
self.pcap.capture(frame)
self.connected_node.forward_frame(frame=frame, incoming_port=self)
connected_node: Node = self.connected_node
connected_node.forward_frame(frame=frame, incoming_port=self)
return True
return False
@@ -482,13 +483,13 @@ class Link(SimComponent):
def endpoint_up(self):
"""Let the Link know and endpoint has been brought up."""
if self.is_up:
_LOGGER.info(f"Link {self} up")
_LOGGER.debug(f"Link {self} up")
def endpoint_down(self):
"""Let the Link know and endpoint has been brought down."""
if not self.is_up:
self.current_load = 0.0
_LOGGER.info(f"Link {self} down")
_LOGGER.debug(f"Link {self} down")
@property
def is_up(self) -> bool:
@@ -515,7 +516,7 @@ class Link(SimComponent):
"""
can_transmit = self._can_transmit(frame)
if not can_transmit:
_LOGGER.info(f"Cannot transmit frame as {self} is at capacity")
_LOGGER.debug(f"Cannot transmit frame as {self} is at capacity")
return False
receiver = self.endpoint_a
@@ -527,7 +528,7 @@ class Link(SimComponent):
# Frame transmitted successfully
# Load the frame size on the link
self.current_load += frame_size
_LOGGER.info(
_LOGGER.debug(
f"Added {frame_size:.3f} Mbits to {self}, current load {self.current_load:.3f} Mbits "
f"({self.current_load_percent})"
)
@@ -886,6 +887,8 @@ class Node(SimComponent):
"The NICs on the node."
ethernet_port: Dict[int, NIC] = {}
"The NICs on the node by port id."
dns_server: Optional[IPv4Address] = None
"List of IP addresses of DNS servers used for name resolution."
accounts: Dict[str, Account] = {}
"All accounts on the node."
@@ -897,6 +900,8 @@ class Node(SimComponent):
"All processes on the node."
file_system: FileSystem
"The nodes file system."
root: Path
"Root directory for simulation output."
sys_log: SysLog
arp: ARPCache
icmp: ICMP
@@ -924,14 +929,20 @@ class Node(SimComponent):
kwargs["icmp"] = ICMP(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):
kwargs["software_manager"] = SoftwareManager(
sys_log=kwargs.get("sys_log"), session_manager=kwargs.get("session_manager")
sys_log=kwargs.get("sys_log"),
session_manager=kwargs.get("session_manager"),
file_system=kwargs.get("file_system"),
dns_server=kwargs.get("dns_server"),
)
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem()
super().__init__(**kwargs)
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
def _init_action_manager(self) -> ActionManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
@@ -975,7 +986,25 @@ class Node(SimComponent):
)
return state
def show(self, markdown: bool = False):
def show(self, markdown: bool = False, component: Literal["NIC", "OPEN_PORTS"] = "NIC"):
"""A multi-use .show function that accepts either NIC or OPEN_PORTS."""
if component == "NIC":
self._show_nic(markdown)
elif component == "OPEN_PORTS":
self._show_open_ports(markdown)
def _show_open_ports(self, markdown: bool = False):
"""Prints a table of the open ports on the Node."""
table = PrettyTable(["Port", "Name"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.hostname} Open Ports"
for port in self.software_manager.get_open_ports():
table.add_row([port.value, port.name])
print(table)
def _show_nic(self, markdown: bool = False):
"""Prints a table of the NICs on the Node."""
table = PrettyTable(["Port", "MAC Address", "Address", "Speed", "Status"])
if markdown:
@@ -1066,29 +1095,30 @@ class Node(SimComponent):
:param pings: The number of pings to attempt, default is 4.
:return: True if the ping is successful, otherwise False.
"""
if not isinstance(target_ip_address, IPv4Address):
target_ip_address = IPv4Address(target_ip_address)
if target_ip_address.is_loopback:
self.sys_log.info("Pinging loopback address")
return any(nic.enabled for nic in self.nics.values())
if self.operating_state == NodeOperatingState.ON:
self.sys_log.info(f"Pinging {target_ip_address}:")
sequence, identifier = 0, None
while sequence < pings:
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
request_replies = self.icmp.request_replies.get(identifier)
passed = request_replies == pings
if request_replies:
self.icmp.request_replies.pop(identifier)
else:
request_replies = 0
self.sys_log.info(
f"Ping statistics for {target_ip_address}: "
f"Packets: Sent = {pings}, "
f"Received = {request_replies}, "
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
)
return passed
if not isinstance(target_ip_address, IPv4Address):
target_ip_address = IPv4Address(target_ip_address)
if target_ip_address.is_loopback:
self.sys_log.info("Pinging loopback address")
return any(nic.enabled for nic in self.nics.values())
if self.operating_state == NodeOperatingState.ON:
self.sys_log.info(f"Pinging {target_ip_address}:")
sequence, identifier = 0, None
while sequence < pings:
sequence, identifier = self.icmp.ping(target_ip_address, sequence, identifier, pings)
request_replies = self.icmp.request_replies.get(identifier)
passed = request_replies == pings
if request_replies:
self.icmp.request_replies.pop(identifier)
else:
request_replies = 0
self.sys_log.info(
f"Ping statistics for {target_ip_address}: "
f"Packets: Sent = {pings}, "
f"Received = {request_replies}, "
f"Lost = {pings-request_replies} ({(pings-request_replies)/pings*100}% loss)"
)
return passed
return False
def send_frame(self, frame: Frame):
@@ -1097,7 +1127,8 @@ class Node(SimComponent):
:param frame: The Frame to be sent.
"""
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
if self.operating_state == NodeOperatingState.ON:
nic: NIC = self._get_arp_cache_nic(frame.ip.dst_ip_address)
nic.send_frame(frame)
def receive_frame(self, frame: Frame, from_nic: NIC):
@@ -1110,18 +1141,27 @@ class Node(SimComponent):
:param frame: The Frame being received.
:param from_nic: The NIC that received the frame.
"""
if frame.ip:
if frame.ip.src_ip_address in self.arp:
self.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
)
if frame.ip.protocol == IPProtocol.TCP:
if frame.tcp.src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
elif frame.ip.protocol == IPProtocol.UDP:
pass
elif frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
if self.operating_state == NodeOperatingState.ON:
if frame.ip:
if frame.ip.src_ip_address in self.arp:
self.arp.add_arp_cache_entry(
ip_address=frame.ip.src_ip_address, mac_address=frame.ethernet.src_mac_addr, nic=from_nic
)
if frame.ip.protocol == IPProtocol.ICMP:
self.icmp.process_icmp(frame=frame, from_nic=from_nic)
return
# Check if the destination port is open on the Node
if frame.tcp.dst_port in self.software_manager.get_open_ports():
# accept thr frame as the port is open
if frame.tcp.src_port == Port.ARP:
self.arp.process_arp_packet(from_nic=from_nic, arp_packet=frame.arp)
else:
self.session_manager.receive_frame(frame)
else:
# denied as port closed
self.sys_log.info(f"Ignoring frame for port {frame.tcp.dst_port.value} from {frame.ip.src_ip_address}")
# TODO: do we need to do anything more here?
pass
def install_service(self, service: Service) -> None:
"""

View File

@@ -1,3 +1,5 @@
from ipaddress import IPv4Address
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC
from primaite.simulator.network.hardware.nodes.computer import Computer
@@ -6,6 +8,11 @@ from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def client_server_routed() -> Network:
@@ -121,16 +128,33 @@ def arcd_uc2_network() -> Network:
# Client 1
client_1 = Computer(
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
)
client_1.power_on()
client_1.software_manager.install(DNSClient)
client_1_dns_client_service: DNSServer = client_1.software_manager.software["DNSClient"] # noqa
client_1_dns_client_service.start()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
# Client 2
client_2 = Computer(
hostname="client_2", ip_address="192.168.10.22", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
dns_server=IPv4Address("192.168.1.10"),
)
client_2.power_on()
client_2.software_manager.install(DNSClient)
client_2_dns_client_service: DNSServer = client_2.software_manager.software["DNSClient"] # noqa
client_2_dns_client_service.start()
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
@@ -141,14 +165,9 @@ def arcd_uc2_network() -> Network:
default_gateway="192.168.1.1",
)
domain_controller.power_on()
network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
domain_controller.software_manager.install(DNSServer)
# Web Server
web_server = Server(
hostname="web_server", ip_address="192.168.1.12", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
web_server.power_on()
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
network.connect(endpoint_b=domain_controller.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
# Database Server
database_server = Server(
@@ -156,13 +175,73 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
database_server.power_on()
network.connect(endpoint_b=database_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[3])
ddl = """
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(50) NOT NULL,
email VARCHAR(50) NOT NULL,
age INT,
city VARCHAR(50),
occupation VARCHAR(50)
);"""
user_insert_statements = [
"INSERT INTO user (name, email, age, city, occupation) VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", # noqa
"INSERT INTO user (name, email, age, city, occupation) VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", # noqa
]
database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
database_service.start()
database_service._process_sql(ddl, None) # noqa
for insert_statement in user_insert_statements:
database_service._process_sql(insert_statement, None) # noqa
# Web Server
web_server = Server(
hostname="web_server",
ip_address="192.168.1.12",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
database_client.run()
database_client.connect()
# register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
dns_server_service.start()
dns_server_service.dns_register("arcd.com", web_server.ip_address)
# Backup Server
backup_server = Server(
hostname="backup_server", ip_address="192.168.1.16", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="backup_server",
ip_address="192.168.1.16",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
backup_server.power_on()
network.connect(endpoint_b=backup_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[4])
@@ -173,6 +252,7 @@ def arcd_uc2_network() -> Network:
ip_address="192.168.1.110",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
security_suite.power_on()
network.connect(endpoint_b=security_suite.ethernet_port[1], endpoint_a=switch_1.switch_ports[7])
@@ -183,4 +263,12 @@ def arcd_uc2_network() -> Network:
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
# Allow PostgreSQL requests
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
# Allow DNS requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
return network

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Optional
from pydantic import BaseModel
class DNSRequest(BaseModel):
"""Represents a DNS Request packet of a network frame.
:param domain_name_request: Domain Name Request for IP address.
"""
domain_name_request: str
"Domain Name Request for IP address."
class DNSReply(BaseModel):
"""Represents a DNS Reply packet of a network frame.
:param domain_name_ip_address: IP Address of the Domain Name requested.
"""
domain_name_ip_address: Optional[IPv4Address] = None
"IP Address of the Domain Name requested."
class DNSPacket(BaseModel):
"""
Represents the DNS layer of a network frame.
:param dns_request: DNS Request packet sent by DNS Client.
:param dns_reply: DNS Reply packet generated by DNS Server.
:Example:
>>> dns_request = DNSPacket(
... domain_name_request=DNSRequest(domain_name_request="www.google.co.uk"),
... dns_reply=None
... )
>>> dns_response = DNSPacket(
... dns_request=DNSRequest(domain_name_request="www.google.co.uk"),
... dns_reply=DNSReply(domain_name_ip_address=IPv4Address("142.250.179.227"))
... )
"""
dns_request: DNSRequest
"DNS Request packet sent by DNS Client."
dns_reply: Optional[DNSReply] = None
"DNS Reply packet generated by DNS Server."
def generate_reply(self, domain_ip_address: IPv4Address) -> DNSPacket:
"""Generate a new DNSPacket to be sent as a response with a DNS Reply packet which contains the IP address.
:param domain_ip_address: The IP address that was being sought after from the original target domain name.
:return: A new instance of DNSPacket.
"""
self.dns_reply = DNSReply(domain_name_ip_address=domain_ip_address)
return self

View File

@@ -59,6 +59,8 @@ class Port(Enum):
"Alternative port for HTTP (HTTP_ALT) - Often used as an alternative HTTP port for web applications."
HTTPS_ALT = 8443
"Alternative port for HTTPS (HTTPS_ALT) - Used in some configurations for secure web traffic."
POSTGRES_SERVER = 5432
"Postgres SQL Server."
class UDPHeader(BaseModel):

View File

@@ -23,9 +23,9 @@ class Application(IOSoftware):
Applications are user-facing programs that may perform input/output operations.
"""
operating_state: ApplicationOperatingState
operating_state: ApplicationOperatingState = ApplicationOperatingState.CLOSED
"The current operating state of the Application."
execution_control_status: str
execution_control_status: str = "manual"
"Control status of the application's execution. It could be 'manual' or 'automatic'."
num_executions: int = 0
"The number of times the application has been executed. Default is 0."
@@ -53,6 +53,25 @@ class Application(IOSoftware):
)
return state
def run(self) -> None:
"""Open the Application."""
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
def close(self) -> None:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:
self.sys_log.info(f"Closed Application{self.name}")
self.operating_state = ApplicationOperatingState.CLOSED
def install(self) -> None:
"""Install Application."""
super().install()
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Installing Application {self.name}")
self.operating_state = ApplicationOperatingState.INSTALLING
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.

View File

@@ -0,0 +1,157 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from prettytable import PrettyTable
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.software_manager import SoftwareManager
class DatabaseClient(Application):
"""
A DatabaseClient application.
Extends the Application class to provide functionality for connecting, querying, and disconnecting from a
Database Service. It mainly operates over TCP protocol.
:ivar server_ip_address: The IPv4 address of the Database Service server, defaults to None.
"""
server_ip_address: Optional[IPv4Address] = None
server_password: Optional[str] = None
connected: bool = False
_query_success_tracker: Dict[str, bool] = {}
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseClient"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
:return: A dictionary representing the current state.
"""
pass
return super().describe_state()
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
"""
Configure the DatabaseClient to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
"""
self.server_ip_address = server_ip_address
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {server_password=}.")
def connect(self) -> bool:
"""Connect to a Database Service."""
if not self.connected and self.operating_state.RUNNING:
return self._connect(self.server_ip_address, self.server_password)
return False
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
) -> bool:
if is_reattempt:
if self.connected:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} authorised")
self.server_ip_address = server_ip_address
return self.connected
else:
self.sys_log.info(f"DatabaseClient connected to {server_ip_address} declined")
return False
payload = {"type": "connect_request", "password": password}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
)
return self._connect(server_ip_address, password, True)
def disconnect(self):
"""Disconnect from the Database Service."""
if self.connected and self.operating_state.RUNNING:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
)
self.sys_log.info(f"DatabaseClient disconnected from {self.server_ip_address}")
self.server_ip_address = None
self.connected = False
def _query(self, sql: str, query_id: str, is_reattempt: bool = False) -> bool:
if is_reattempt:
success = self._query_success_tracker.get(query_id)
if success:
return True
return False
else:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "sql", "sql": sql, "uuid": query_id},
dest_ip_address=self.server_ip_address,
dest_port=self.port,
)
return self._query(sql=sql, query_id=query_id, is_reattempt=True)
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
self.operating_state = ApplicationOperatingState.RUNNING
self.connect()
def query(self, sql: str) -> bool:
"""
Send a query to the Database Service.
:param sql: The SQL query.
:return: True if the query was successful, otherwise False.
"""
if self.connected and self.operating_state.RUNNING:
query_id = str(uuid4())
# Initialise the tracker of this ID to False
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
def _print_data(self, data: Dict):
"""
Display the contents of the Folder in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
if data:
table = PrettyTable(list(data.values())[0])
table.align = "l"
table.title = f"{self.sys_log.hostname} Database Client"
for row in data.values():
table.add_row(row.values())
print(table)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response":
self.connected = payload["response"] == True
elif payload["type"] == "sql":
query_id = payload.get("uuid")
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
self._print_data(payload["data"])
return True

View File

@@ -0,0 +1,54 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.system.applications.application import Application
class WebBrowser(Application):
"""
Represents a web browser in the simulation environment.
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
domain_name: str
"The domain name of the webpage."
domain_name_ip_address: Optional[IPv4Address]
"The IP address of the domain name for the webpage."
history: Dict[str]
"A dict that stores all of the previous domain names."
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name = ""
self.domain_name_ip_address = None
self.history = {}
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:return: True if successful, False otherwise.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:return: True if successful, False otherwise.
"""
pass

View File

@@ -1,8 +1,9 @@
import json
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Dict, List, Optional
from primaite.simulator import TEMP_SIM_OUTPUT
from primaite.simulator import SIM_OUTPUT
class _JSONFilter(logging.Filter):
@@ -51,6 +52,18 @@ class PacketCapture:
self.logger.addFilter(_JSONFilter())
def read(self) -> List[Dict[str, Any]]:
"""
Read packet capture logs and return them as a list of dictionaries.
:return: List of frames captured, represented as dictionaries.
"""
frames = []
with open(self._get_log_path(), "r") as file:
while line := file.readline():
frames.append(json.loads(line.rstrip()))
return frames
@property
def _logger_name(self) -> str:
"""Get PCAP the logger name."""
@@ -62,7 +75,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = TEMP_SIM_OUTPUT / self.hostname
root = SIM_OUTPUT / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -1,12 +1,14 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import SimComponent
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
if TYPE_CHECKING:
from primaite.simulator.network.hardware.base import ARPCache
@@ -30,27 +32,23 @@ class Session(SimComponent):
"""
protocol: IPProtocol
src_ip_address: IPv4Address
dst_ip_address: IPv4Address
with_ip_address: IPv4Address
src_port: Optional[Port]
dst_port: Optional[Port]
connected: bool = False
@classmethod
def from_session_key(
cls, session_key: Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]
) -> Session:
def from_session_key(cls, session_key: Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]) -> Session:
"""
Create a Session instance from a session key tuple.
:param session_key: Tuple containing the session details.
:return: A Session instance.
"""
protocol, src_ip_address, dst_ip_address, src_port, dst_port = session_key
protocol, with_ip_address, src_port, dst_port = session_key
return Session(
protocol=protocol,
src_ip_address=src_ip_address,
dst_ip_address=dst_ip_address,
with_ip_address=with_ip_address,
src_port=src_port,
dst_port=dst_port,
)
@@ -97,8 +95,8 @@ class SessionManager:
@staticmethod
def _get_session_key(
frame: Frame, from_source: bool = True
) -> Tuple[IPProtocol, IPv4Address, IPv4Address, Optional[Port], Optional[Port]]:
frame: Frame, inbound_frame: bool = True
) -> Tuple[IPProtocol, IPv4Address, Optional[Port], Optional[Port]]:
"""
Extracts the session key from the given frame.
@@ -110,32 +108,39 @@ class SessionManager:
- Optional[Port]: The destination port number (if applicable).
:param frame: The network frame from which to extract the session key.
:param from_source: A flag to indicate if the key should be extracted from the source or destination.
:return: A tuple containing the session key.
"""
protocol = frame.ip.protocol
src_ip_address = frame.ip.src_ip_address
dst_ip_address = frame.ip.dst_ip_address
with_ip_address = frame.ip.src_ip_address
if protocol == IPProtocol.TCP:
if from_source:
if inbound_frame:
src_port = frame.tcp.src_port
dst_port = frame.tcp.dst_port
else:
dst_port = frame.tcp.src_port
src_port = frame.tcp.dst_port
with_ip_address = frame.ip.dst_ip_address
elif protocol == IPProtocol.UDP:
if from_source:
if inbound_frame:
src_port = frame.udp.src_port
dst_port = frame.udp.dst_port
else:
dst_port = frame.udp.src_port
src_port = frame.udp.dst_port
with_ip_address = frame.ip.dst_ip_address
else:
src_port = None
dst_port = None
return protocol, src_ip_address, dst_ip_address, src_port, dst_port
return protocol, with_ip_address, src_port, dst_port
def receive_payload_from_software_manager(self, payload: Any, session_id: Optional[int] = None):
def receive_payload_from_software_manager(
self,
payload: Any,
dst_ip_address: Optional[IPv4Address] = None,
dst_port: Optional[Port] = None,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> Union[Any, None]:
"""
Receive a payload from the SoftwareManager.
@@ -144,46 +149,87 @@ class SessionManager:
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional. If None, one will be created.
"""
# TODO: Implement session creation and
if session_id:
session = self.sessions_by_uuid[session_id]
dst_ip_address = self.sessions_by_uuid[session_id].with_ip_address
dst_port = self.sessions_by_uuid[session_id].dst_port
self.send_payload_to_nic(payload, session_id)
dst_mac_address = self.arp_cache.get_arp_cache_mac_address(dst_ip_address)
def send_payload_to_software_manager(self, payload: Any, session_id: int):
if dst_mac_address:
outbound_nic = self.arp_cache.get_arp_cache_nic(dst_ip_address)
else:
if not is_reattempt:
self.arp_cache.send_arp_request(dst_ip_address)
return self.receive_payload_from_software_manager(
payload=payload,
dst_ip_address=dst_ip_address,
dst_port=dst_port,
session_id=session_id,
is_reattempt=True,
)
else:
return
frame = Frame(
ethernet=EthernetHeader(src_mac_addr=outbound_nic.mac_address, dst_mac_addr=dst_mac_address),
ip=IPPacket(
src_ip_address=outbound_nic.ip_address,
dst_ip_address=dst_ip_address,
),
tcp=TCPHeader(
src_port=dst_port,
dst_port=dst_port,
),
payload=payload,
)
if not session_id:
session_key = self._get_session_key(frame, inbound_frame=False)
session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
outbound_nic.send_frame(frame)
def receive_frame(self, frame: Frame):
"""
Send a payload to the software manager.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload originates from.
"""
self.software_manager.receive_payload_from_session_manger()
def send_payload_to_nic(self, payload: Any, session_id: int):
"""
Send a payload across the Network.
Takes a payload and a session_id. Builds a Frame and sends it across the network via a NIC.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload originates from
"""
# TODO: Implement frame construction and sent to NIC.
pass
def receive_payload_from_nic(self, frame: Frame):
"""
Receive a Frame from the NIC.
Receive a Frame.
Extract the session key using the _get_session_key method, and forward the payload to the appropriate
session. If the session does not exist, a new one is created.
:param frame: The frame being received.
"""
session_key = self._get_session_key(frame)
session = self.sessions_by_key.get(session_key)
session_key = self._get_session_key(frame, inbound_frame=True)
session: Session = self.sessions_by_key.get(session_key)
if not session:
# Create new session
session = Session.from_session_key(session_key)
self.sessions_by_key[session_key] = session
self.sessions_by_uuid[session.uuid] = session
self.software_manager.receive_payload_from_session_manger(payload=frame, session=session)
# TODO: Implement the frame deconstruction and send to SoftwareManager.
self.software_manager.receive_payload_from_session_manager(
payload=frame.payload, port=frame.tcp.dst_port, protocol=frame.ip.protocol, session_id=session.uuid
)
def show(self, markdown: bool = False):
"""
Print tables describing the SessionManager.
Generate and print PrettyTable instances that show details about
session's destination IP Address, destination Ports and the protocol to use.
Output can be in Markdown format.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
table = PrettyTable(["Destination IP", "Port", "Protocol"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} Session Manager"
for session in self.sessions_by_key.values():
table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name])
print(table)

View File

@@ -1,99 +1,162 @@
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import SoftwareType
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import IOSoftware
if TYPE_CHECKING:
from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from typing import Type, TypeVar
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
class SoftwareManager:
"""A class that manages all running Services and Applications on a Node and facilitates their communication."""
def __init__(self, session_manager: "SessionManager", sys_log: "SysLog"):
def __init__(
self,
session_manager: "SessionManager",
sys_log: SysLog,
file_system: FileSystem,
dns_server: Optional[IPv4Address],
):
"""
Initialize a new instance of SoftwareManager.
:param session_manager: The session manager handling network communications.
"""
self.session_manager = session_manager
self.services: Dict[str, Service] = {}
self.applications: Dict[str, Application] = {}
self.software: Dict[str, Union[Service, Application]] = {}
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
self.dns_server: Optional[IPv4Address] = dns_server
def add_service(self, name: str, service: Service, port: Port, protocol: IPProtocol):
def get_open_ports(self) -> List[Port]:
"""
Add a Service to the manager.
Get a list of open ports.
:param name: The name of the service.
:param service: The service instance.
:param port: The port used by the service.
:param protocol: The network protocol used by the service.
:return: A list of all open ports on the Node.
"""
service.software_manager = self
self.services[name] = service
self.port_protocol_mapping[(port, protocol)] = service
open_ports = [Port.ARP]
for software in self.port_protocol_mapping.values():
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
open_ports.append(software.port)
open_ports.sort(key=lambda port: port.value)
return open_ports
def add_application(self, name: str, application: Application, port: Port, protocol: IPProtocol):
def install(self, software_class: Type[IOSoftwareClass]):
"""
Add an Application to the manager.
Install an Application or Service.
:param name: The name of the application.
:param application: The application instance.
:param port: The port used by the application.
:param protocol: The network protocol used by the application.
:param software_class: The software class.
"""
application.software_manager = self
self.applications[name] = application
self.port_protocol_mapping[(port, protocol)] = application
if software_class in self._software_class_to_name_map:
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
)
if isinstance(software, Application):
software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
self.sys_log.info(f"Installed {software.name}")
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
def send_internal_payload(self, target_software: str, target_software_type: SoftwareType, payload: Any):
def uninstall(self, software_name: str):
"""
Uninstall an Application or Service.
:param software_name: The software name.
"""
if software_name in self.software:
software = self.software.pop(software_name) # noqa
del software
self.sys_log.info(f"Deleted {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
def send_internal_payload(self, target_software: str, payload: Any):
"""
Send a payload to a specific service or application.
:param target_software: The name of the target service or application.
:param target_software_type: The type of software (Service, Application, Process).
:param payload: The data to be sent.
:param receiver_type: The type of the target, either 'service' or 'application'.
"""
if target_software_type is SoftwareType.SERVICE:
receiver = self.services.get(target_software)
elif target_software_type is SoftwareType.APPLICATION:
receiver = self.applications.get(target_software)
else:
raise ValueError(f"Invalid receiver type {target_software_type}")
receiver = self.software.get(target_software)
if receiver:
receiver.receive_payload(payload)
else:
raise ValueError(f"No {target_software_type.name.lower()} found with the name {target_software}")
self.sys_log.error(f"No Service of Application found with the name {target_software}")
def send_payload_to_session_manger(self, payload: Any, session_id: Optional[int] = None):
def send_payload_to_session_manager(
self,
payload: Any,
dest_ip_address: Optional[IPv4Address] = None,
dest_port: Optional[Port] = None,
session_id: Optional[str] = None,
):
"""
Send a payload to the SessionManager.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
"""
self.session_manager.receive_payload_from_software_manager(payload, session_id)
self.session_manager.receive_payload_from_software_manager(
payload=payload, dst_ip_address=dest_ip_address, dst_port=dest_port, session_id=session_id
)
def receive_payload_from_session_manger(self, payload: Any, session: Session):
def receive_payload_from_session_manager(self, payload: Any, port: Port, protocol: IPProtocol, session_id: str):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.
:param payload: The payload being received.
:param session: The transport session the payload originates from.
"""
# receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
# if receiver:
# receiver.receive_payload(None, payload)
# else:
# raise ValueError(f"No service or application found for port {port} and protocol {protocol}")
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive(payload=payload, session_id=session_id)
else:
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.
:param markdown: If True, outputs the table in markdown format. Default is False.
"""
table = PrettyTable(["Name", "Type", "Operating State", "Health State", "Port"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} Software Manager"
for software in self.port_protocol_mapping.values():
software_type = "Service" if isinstance(software, Service) else "Application"
table.add_row(
[
software.name,
software_type,
software.operating_state.name,
software.health_state_actual.name,
software.port.value,
]
)
print(table)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator import TEMP_SIM_OUTPUT
from primaite.simulator import SIM_OUTPUT
class _NotJSONFilter(logging.Filter):
@@ -81,7 +81,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = TEMP_SIM_OUTPUT / self.hostname
root = SIM_OUTPUT / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -1,76 +0,0 @@
from typing import Dict
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.system.services.service import Service
class DatabaseService(Service):
"""Service loosely modelled on Microsoft SQL Server."""
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def uninstall(self) -> None:
"""
Undo installation procedure.
This method deletes files created when installing the database, and the database folder if it is empty.
"""
super().uninstall()
node: Node = self.parent
node.file_system.delete_file(self.primary_store)
node.file_system.delete_file(self.transaction_log)
if self.secondary_store:
node.file_system.delete_file(self.secondary_store)
if len(self.folder.files) == 0:
node.file_system.delete_folder(self.folder)
def install(self) -> None:
"""Perform first time install on a node, creating necessary files."""
super().install()
assert isinstance(self.parent, Node), "Database install can only happen after the db service is added to a node"
self._setup_files()
def _setup_files(
self,
db_size: int = 1000,
use_secondary_db_file: bool = False,
secondary_db_size: int = 300,
folder_name: str = "database",
):
"""Set up files that are required by the database on the parent host.
:param db_size: Initial file size of the main database file, defaults to 1000
:type db_size: int, optional
:param use_secondary_db_file: Whether to use a secondary database file, defaults to False
:type use_secondary_db_file: bool, optional
:param secondary_db_size: Size of the secondary db file, defaults to None
:type secondary_db_size: int, optional
:param folder_name: Name of the folder which will be setup to hold the db files, defaults to "database"
:type folder_name: str, optional
"""
# note that this parent.file_system.create_folder call in the future will be authenticated by using permissions
# handler. This permission will be granted based on service account given to the database service.
self.parent: Node
self.folder = self.parent.file_system.create_folder(folder_name)
self.primary_store = self.parent.file_system.create_file(
"db_primary_store", db_size, FileSystemFileType.MDF, folder=self.folder
)
self.transaction_log = self.parent.file_system.create_file(
"db_transaction_log", "1", FileSystemFileType.LDF, folder=self.folder
)
if use_secondary_db_file:
self.secondary_store = self.parent.file_system.create_file(
"db_secondary_store", secondary_db_size, FileSystemFileType.NDF, folder=self.folder
)
else:
self.secondary_store = None

View File

@@ -0,0 +1,155 @@
import sqlite3
from datetime import datetime
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
class DatabaseService(Service):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
"""
password: Optional[str] = None
connections: Dict[str, datetime] = {}
def __init__(self, **kwargs):
kwargs["name"] = "DatabaseService"
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._conn = sqlite3.connect(self._db_file.sim_path)
self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
"""
Get a list of table names present in the database.
:return: List of table names.
"""
sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
results = self._process_sql(sql)
return [row[0] for row in results["data"]]
def show(self, markdown: bool = False):
"""
Prints a list of table names in the database using PrettyTable.
:param markdown: Whether to output the table in Markdown format.
"""
table = PrettyTable(["Table"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.file_system.sys_log.hostname} Database"
for row in self.tables():
table.add_row([row])
print(table)
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
def _process_connect(
self, session_id: str, password: Optional[str] = None
) -> Dict[str, Union[int, Dict[str, bool]]]:
status_code = 500 # Default internal server error
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.GOOD:
if self.password == password:
status_code = 200 # ok
self.connections[session_id] = datetime.now()
self.sys_log.info(f"Connect request for {session_id=} authorised")
else:
status_code = 401 # Unauthorised
self.sys_log.info(f"Connect request for {session_id=} declined")
else:
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
try:
self._cursor.execute(query)
self._conn.commit()
except OperationalError:
# Handle the case where the table does not exist.
self.sys_log.error(f"{self.name}: Error, query failed")
return {"status_code": 404, "data": {}}
data = []
description = self._cursor.description
if description:
headers = []
for header in description:
headers.append(header[0])
data = self._cursor.fetchall()
if data and headers:
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return super().describe_state()
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Processes the incoming SQL payload and sends the result back.
:param payload: The SQL query to be executed.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
result = {"status_code": 500, "data": []}
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":
result = self._process_connect(session_id=session_id, password=payload.get("password"))
elif payload["type"] == "disconnect":
if session_id in self.connections:
self.connections.pop(session_id)
elif payload["type"] == "sql":
if session_id in self.connections:
result = self._process_sql(query=payload["sql"], query_id=payload["uuid"])
else:
result = {"status_code": 401, "type": "sql"}
self.send(payload=result, session_id=session_id)
return True
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Send a SQL response back down to the SessionManager.
:param payload: The SQL query results.
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return payload["status_code"] == 200

View File

@@ -0,0 +1,154 @@
from ipaddress import IPv4Address
from typing import Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class DNSClient(Service):
"""Represents a DNS Client as a Service."""
dns_cache: Dict[str, IPv4Address] = {}
"A dict of known mappings between domain/URLs names and IPv4 addresses."
dns_server: Optional[IPv4Address] = None
"The DNS Server the client sends requests to."
def __init__(self, **kwargs):
kwargs["name"] = "DNSClient"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
"""
Adds a domain name to the DNS Client cache.
:param: domain_name: The domain name to save to cache
:param: ip_address: The IP Address to attach the domain name to
"""
self.dns_cache[domain_name] = ip_address
def check_domain_exists(
self,
target_domain: str,
session_id: Optional[str] = None,
is_reattempt: bool = False,
) -> bool:
"""Function to check if domain name exists.
:param: target_domain: The domain requested for an IP address.
:param: session_id: The Session ID the payload is to originate from. Optional.
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
"""
# check if the target domain is in the client's DNS cache
payload = DNSPacket(dns_request=DNSRequest(domain_name_request=target_domain))
# check if the domain is already in the DNS cache
if target_domain in self.dns_cache:
self.sys_log.info(
f"DNS Client: Domain lookup for {target_domain} successful, resolves to {self.dns_cache[target_domain]}"
)
return True
else:
# return False if already reattempted
if is_reattempt:
self.sys_log.info(f"DNS Client: Domain lookup for {target_domain} failed")
return False
else:
# send a request to check if domain name exists in the DNS Server
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=self.dns_server, dest_port=Port.DNS
)
# recursively re-call the function passing is_reattempt=True
return self.check_domain_exists(
target_domain=target_domain,
session_id=session_id,
is_reattempt=True,
)
def send(
self,
payload: DNSPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param dest_ip_address: The ip address of the payload destination.
:param dest_port: The port of the payload destination.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
# create DNS request packet
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
return True
def receive(
self,
payload: DNSPacket,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to be sent.
:param session_id: The Session ID the payload is to originate from. Optional.
:return: True if successful, False otherwise.
"""
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_reply is not None:
# add the IP address to the client cache
if payload.dns_reply.domain_name_ip_address:
self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address
return True
return False

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite import getLogger
from primaite.simulator.network.protocols.dns import DNSPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class DNSServer(Service):
"""Represents a DNS Server as a Service."""
dns_table: Dict[str, IPv4Address] = {}
"A dict of mappings between domain names and IPv4 addresses."
def __init__(self, **kwargs):
kwargs["name"] = "DNSServer"
kwargs["port"] = Port.DNS
# DNS uses UDP by default
# it switches to TCP when the bytes exceed 512 (or 4096) bytes
# TCP for now
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
The specifics of the software's state, including its health, criticality,
and any other pertinent information, should be implemented in subclasses.
:return: A dictionary containing key-value pairs representing the current state of the software.
:rtype: Dict
"""
state = super().describe_state()
return state
def dns_lookup(self, target_domain: str) -> Optional[IPv4Address]:
"""
Attempts to find the IP address for a domain name.
:param target_domain: The single domain name requested by a DNS client.
:return ip_address: The IP address of that domain name or None.
"""
return self.dns_table.get(target_domain)
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
"""
Register a domain name and its IP address.
:param: domain_name: The domain name to register
:type: domain_name: str
:param: domain_ip_address: The IP address that the domain should route to
:type: domain_ip_address: IPv4Address
"""
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param: payload: The payload to send.
:param: session_id: The id of the session. Optional.
:return: True if DNS request returns a valid IP, otherwise, False
"""
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_request is not None:
self.sys_log.info(
f"DNS Server: Received domain lookup request for {payload.dns_request.domain_name_request} "
f"from session {session_id}"
)
# generate a reply with the correct DNS IP address
payload = payload.generate_reply(self.dns_lookup(payload.dns_request.domain_name_request))
self.sys_log.info(
f"DNS Server: Responding to domain lookup request for {payload.dns_request.domain_name_request} "
f"with ip address: {payload.dns_reply.domain_name_ip_address}"
)
# send reply
self.send(payload, session_id)
return payload.dns_reply.domain_name_ip_address is not None
return False
def show(self, markdown: bool = False):
"""Prints a table of DNS Lookup table."""
table = PrettyTable(["Domain Name", "IP Address"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} DNS Lookup table"
for dns in self.dns_table.items():
table.add_row([dns[0], dns[1]])
print(table)

View File

@@ -0,0 +1,49 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite.simulator.system.applications.database_client import DatabaseClient
class DataManipulationBot(DatabaseClient):
"""
Red Agent Data Integration Service.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
"""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The data manipulation query payload.
"""
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.sys_log.info(f"Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}.")
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"Attempting to start the {self.name}")
super().run()
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_io_address and payload.")

View File

@@ -1,4 +1,3 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
@@ -33,7 +32,7 @@ class Service(IOSoftware):
Services are programs that run in the background and may perform input/output operations.
"""
operating_state: ServiceOperatingState
operating_state: ServiceOperatingState = ServiceOperatingState.STOPPED
"The current operating state of the Service."
restart_duration: int = 5
"How many timesteps does it take to restart this service."
@@ -51,7 +50,6 @@ class Service(IOSoftware):
am.add_action("enable", Action(func=lambda request, context: self.enable()))
return am
@abstractmethod
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -74,77 +72,85 @@ class Service(IOSoftware):
"""
pass
def send(self, payload: Any, session_id: str, **kwargs) -> bool:
def send(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Sends a payload to the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to send.
:param: payload: The payload to send.
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
self.software_manager.send_payload_to_session_manager(payload=payload, session_id=session_id)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
def receive(
self,
payload: Any,
session_id: Optional[str] = None,
**kwargs,
) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:param: payload: The payload to send.
:param: session_id: The id of the session
:return: True if successful, False otherwise.
"""
pass
pass
def stop(self) -> None:
"""Stop the service."""
_LOGGER.debug(f"Stopping service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.parent.sys_log.info(f"Stopping service {self.name}")
self.sys_log.info(f"Stopping service {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
def start(self) -> None:
def start(self, **kwargs) -> None:
"""Start the service."""
_LOGGER.debug(f"Starting service {self.name}")
if self.operating_state == ServiceOperatingState.STOPPED:
self.parent.sys_log.info(f"Starting service {self.name}")
self.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def pause(self) -> None:
"""Pause the service."""
_LOGGER.debug(f"Pausing service {self.name}")
if self.operating_state == ServiceOperatingState.RUNNING:
self.parent.sys_log.info(f"Pausing service {self.name}")
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.PAUSED
def resume(self) -> None:
"""Resume paused service."""
_LOGGER.debug(f"Resuming service {self.name}")
if self.operating_state == ServiceOperatingState.PAUSED:
self.parent.sys_log.info(f"Resuming service {self.name}")
self.sys_log.info(f"Resuming service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
def restart(self) -> None:
"""Restart running service."""
_LOGGER.debug(f"Restarting service {self.name}")
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
self.parent.sys_log.info(f"Pausing service {self.name}")
self.sys_log.info(f"Pausing service {self.name}")
self.operating_state = ServiceOperatingState.RESTARTING
self.restart_countdown = self.restarting_duration
def disable(self) -> None:
"""Disable the service."""
_LOGGER.debug(f"Disabling service {self.name}")
self.parent.sys_log.info(f"Disabling Application {self.name}")
self.sys_log.info(f"Disabling Application {self.name}")
self.operating_state = ServiceOperatingState.DISABLED
def enable(self) -> None:
"""Enable the disabled service."""
_LOGGER.debug(f"Enabling service {self.name}")
if self.operating_state == ServiceOperatingState.DISABLED:
self.parent.sys_log.info(f"Enabling Application {self.name}")
self.sys_log.info(f"Enabling Application {self.name}")
self.operating_state = ServiceOperatingState.STOPPED
def apply_timestep(self, timestep: int) -> None:

View File

@@ -1,9 +1,11 @@
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Set
from typing import Any, Dict, Optional
from primaite.simulator.core import Action, ActionManager, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.sys_log import SysLog
class SoftwareType(Enum):
@@ -62,11 +64,11 @@ class Software(SimComponent):
name: str
"The name of the software."
health_state_actual: SoftwareHealthState
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
"The actual health state of the software."
health_state_visible: SoftwareHealthState
health_state_visible: SoftwareHealthState = SoftwareHealthState.GOOD
"The health state of the software visible to the red agent."
criticality: SoftwareCriticality
criticality: SoftwareCriticality = SoftwareCriticality.LOWEST
"The criticality level of the software."
patching_count: int = 0
"The count of patches applied to the software, defaults to 0."
@@ -74,6 +76,14 @@ class Software(SimComponent):
"The count of times the software has been scanned, defaults to 0."
revealed_to_red: bool = False
"Indicates if the software has been revealed to red agent, defaults is False."
software_manager: Any = None
"An instance of Software Manager that is used by the parent node."
sys_log: SysLog = None
"An instance of SysLog that is used by the parent node."
file_system: FileSystem
"The FileSystem of the Node the Software is installed on."
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
@@ -132,7 +142,6 @@ class Software(SimComponent):
"""
self.health_state_actual = health_state
@abstractmethod
def install(self) -> None:
"""
Perform first-time setup of this service on a node.
@@ -175,8 +184,8 @@ class IOSoftware(Software):
"Indicates if the software uses TCP protocol for communication. Default is True."
udp: bool = True
"Indicates if the software uses UDP protocol for communication. Default is True."
ports: Set[Port]
"The set of ports to which the software is connected."
port: Port
"The port to which the software is connected."
@abstractmethod
def describe_state(self) -> Dict:
@@ -212,7 +221,6 @@ class IOSoftware(Software):
:param kwargs: Additional keyword arguments specific to the implementation.
:return: True if the payload was successfully sent, False otherwise.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""

View File

@@ -12,6 +12,8 @@ import pytest
from primaite import getLogger
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.simulator.network.container import Network
from primaite.simulator.network.networks import arcd_uc2_network
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
ACTION_SPACE_NODE_VALUES = 1
@@ -19,7 +21,22 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
@pytest.fixture(scope="function")
def uc2_network() -> Network:
return arcd_uc2_network()
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
return Node(hostname="fs_node").file_system
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.

View File

View File

@@ -0,0 +1,25 @@
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_data_manipulation(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
database_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = database_server.software_manager.software["DatabaseService"]
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
# First check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT * FROM user;")

View File

@@ -1,56 +1,59 @@
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database_service import DatabaseService
def test_installing_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)
def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
node = Node(hostname="db-server")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
node.install_service(db)
assert len(db_service.connections) == 1
assert db in node
file_exists = False
for folder in node.file_system.folders.values():
for file in folder.files.values():
if file.name == "db_primary_store":
file_exists = True
break
if file_exists:
break
assert file_exists
db_client.disconnect()
assert len(db_service.connections) == 0
def test_uninstalling_database():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)
def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
node = Node(hostname="db-server")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
node.install_service(db)
db_client.disconnect()
node.uninstall_service(db)
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="12345")
db_service.password = "12345"
assert db not in node
assert node.file_system.get_folder_by_name("database") is None
assert db_client.connect()
assert len(db_service.connections) == 1
def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
db_service.password = "12345"
assert not db_client.connect()
assert len(db_service.connections) == 0
def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT * FROM user;")

View File

@@ -0,0 +1,28 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_dns_client_server(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING
dns_server.show()
# fake domain should not be added to dns cache
assert not dns_client.check_domain_exists(target_domain="fake-domain.com")
assert dns_client.dns_cache.get("fake-domain.com", None) is None
# arcd.com is registered in dns server and should be saved to cache
assert dns_client.check_domain_exists(target_domain="arcd.com")
assert dns_client.dns_cache.get("arcd.com", None) is not None

View File

@@ -45,7 +45,6 @@ def test_seeded_learning(temp_primaite_session):
), "Expected output is based upon a agent that was trained with seed 67890"
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode_dict()
print(actual_mean_reward_per_episode, "THISt")
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode

View File

@@ -1,132 +1,144 @@
import pytest
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
from primaite.simulator.file_system.file_type import FileType
def test_create_folder_and_file():
def test_create_folder_and_file(file_system):
"""Test creating a folder and a file."""
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
assert len(file_system.folders) == 1
file_system.create_folder(folder_name="test_folder")
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid)
assert len(file_system.get_folder_by_id(folder.uuid).files) is 1
assert len(file_system.folders) is 2
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
assert file_system.get_file_by_id(file.uuid).name is "test_file"
assert file_system.get_file_by_id(file.uuid).size == 10
assert len(file_system.get_folder("test_folder").files) == 1
assert file_system.get_folder("test_folder").get_file("test_file.txt")
def test_create_file():
def test_create_file_no_folder(file_system):
"""Tests that creating a file without a folder creates a folder and sets that as the file's parent."""
file_system = FileSystem()
file = file_system.create_file(file_name="test_file", size=10)
file = file_system.create_file(file_name="test_file.txt", size=10)
assert len(file_system.folders) is 1
assert file_system.get_folder_by_name("root").get_file_by_id(file.uuid) is file
assert file_system.get_folder("root").get_file("test_file.txt") == file
assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT
assert file_system.get_folder("root").get_file("test_file.txt").size == 10
def test_delete_file():
def test_create_file_no_extension(file_system):
"""Tests that creating a file without an extension sets the file type to FileType.UNKNOWN."""
file = file_system.create_file(file_name="test_file")
assert len(file_system.folders) is 1
assert file_system.get_folder("root").get_file("test_file") == file
assert file_system.get_folder("root").get_file("test_file").file_type == FileType.UNKNOWN
assert file_system.get_folder("root").get_file("test_file").size == 0
def test_delete_file(file_system):
"""Tests that a file can be deleted."""
file_system = FileSystem()
file_system.create_file(file_name="test_file.txt")
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 1
file = file_system.create_file(file_name="test_file", size=10)
assert len(file_system.folders) is 1
folder_id = list(file_system.folders.keys())[0]
folder = file_system.get_folder_by_id(folder_id)
assert folder.get_file_by_id(file.uuid) is file
file_system.delete_file(file=file)
assert len(file_system.folders) is 1
assert len(folder.files) is 0
file_system.delete_file(folder_name="root", file_name="test_file.txt")
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 0
def test_delete_non_existent_file():
def test_delete_non_existent_file(file_system):
"""Tests deleting a non existent file."""
file_system = FileSystem()
file = file_system.create_file(file_name="test_file", size=10)
not_added_file = FileSystemFile(name="not_added")
file_system.create_file(file_name="test_file.txt")
# folder should be created
assert len(file_system.folders) is 1
assert len(file_system.folders) == 1
# should only have 1 file in the file system
folder_id = list(file_system.folders.keys())[0]
folder = file_system.get_folder_by_id(folder_id)
assert len(list(folder.files)) is 1
assert folder.get_file_by_id(file.uuid) is file
assert len(file_system.get_folder("root").files) == 1
# deleting should not change how many files are in folder
file_system.delete_file(file=not_added_file)
assert len(file_system.folders) is 1
assert len(list(folder.files)) is 1
file_system.delete_file(folder_name="root", file_name="does_not_exist!")
# should still only be one folder
assert len(file_system.folders) == 1
# The folder should still have 1 file
assert len(file_system.get_folder("root").files) == 1
def test_delete_folder():
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
def test_delete_folder(file_system):
file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) == 2
file_system.delete_folder(folder)
assert len(file_system.folders) is 0
file_system.delete_folder(folder_name="test_folder")
assert len(file_system.folders) == 1
def test_deleting_a_non_existent_folder():
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
not_added_folder = FileSystemFolder(name="fake_folder")
assert len(file_system.folders) is 1
def test_deleting_a_non_existent_folder(file_system):
file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) == 2
file_system.delete_folder(not_added_folder)
assert len(file_system.folders) is 1
file_system.delete_folder(folder_name="does not exist!")
assert len(file_system.folders) == 2
def test_move_file():
def test_deleting_root_folder_fails(file_system):
assert len(file_system.folders) == 1
file_system.delete_folder(folder_name="root")
assert len(file_system.folders) == 1
def test_move_file(file_system):
"""Tests the file move function."""
file_system = FileSystem()
src_folder = file_system.create_folder(folder_name="test_folder_1")
assert len(file_system.folders) is 1
file_system.create_folder(folder_name="src_folder")
file_system.create_folder(folder_name="dst_folder")
target_folder = file_system.create_folder(folder_name="test_folder_2")
assert len(file_system.folders) is 2
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder")
original_uuid = file.uuid
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid)
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
file_system.move_file(file=file, src_folder=src_folder, target_folder=target_folder)
file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 0
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1
assert len(file_system.get_folder("src_folder").files) == 0
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid
def test_copy_file():
def test_copy_file(file_system):
"""Tests the file copy function."""
file_system = FileSystem()
src_folder = file_system.create_folder(folder_name="test_folder_1")
assert len(file_system.folders) is 1
file_system.create_folder(folder_name="src_folder")
file_system.create_folder(folder_name="dst_folder")
target_folder = file_system.create_folder(folder_name="test_folder_2")
assert len(file_system.folders) is 2
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
original_uuid = file.uuid
file = file_system.create_file(file_name="test_file", size=10, folder_uuid=src_folder.uuid)
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 0
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
file_system.copy_file(file=file, src_folder=src_folder, target_folder=target_folder)
file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert len(file_system.get_folder_by_id(src_folder.uuid).files) is 1
assert len(file_system.get_folder_by_id(target_folder.uuid).files) is 1
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
def test_serialisation():
def test_folder_quarantine_state(file_system):
"""Tests the changing of folder quarantine status."""
folder = file_system.get_folder("root")
assert folder.quarantine_status() is False
folder.quarantine()
assert folder.quarantine_status() is True
folder.unquarantine()
assert folder.quarantine_status() is False
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly."""
file_system = FileSystem()
folder = file_system.create_folder(folder_name="test_folder")
assert len(file_system.folders) is 1
file_system.create_file(file_name="test_file", size=10, folder_uuid=folder.uuid)
assert file_system.get_folder_by_id(folder.uuid) is folder
file_system.create_file(file_name="test_file.txt")
serialised_file_sys = file_system.model_dump_json()
deserialised_file_sys = FileSystem.model_validate_json(serialised_file_sys)

View File

@@ -1,23 +0,0 @@
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
def test_file_type():
"""Tests tha the FileSystemFile type is set correctly."""
file = FileSystemFile(name="test", file_type=FileSystemFileType.DOC)
assert file.file_type is FileSystemFileType.DOC
def test_get_size():
"""Tests that the file size is being returned properly."""
file = FileSystemFile(name="test", size=1.5)
assert file.size == 1.5
def test_serialisation():
"""Test to check that the object serialisation works correctly."""
file = FileSystemFile(name="test", size=1.5, file_type=FileSystemFileType.DOC)
serialised_file = file.model_dump_json()
deserialised_file = FileSystemFile.model_validate_json(serialised_file)
assert file.model_dump_json() == deserialised_file.model_dump_json()

View File

@@ -1,75 +0,0 @@
from primaite.simulator.file_system.file_system_file import FileSystemFile
from primaite.simulator.file_system.file_system_file_type import FileSystemFileType
from primaite.simulator.file_system.file_system_folder import FileSystemFolder
def test_adding_removing_file():
"""Test the adding and removing of a file from a folder."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
assert folder.size == 10
assert len(folder.files) is 1
folder.remove_file(file)
assert folder.size == 0
assert len(folder.files) is 0
def test_remove_non_existent_file():
"""Test the removing of a file that does not exist."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
not_added_file = FileSystemFile(name="fake_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
assert folder.size == 10
assert len(folder.files) is 1
folder.remove_file(not_added_file)
assert folder.size == 10
assert len(folder.files) is 1
def test_get_file_by_id():
"""Test to make sure that the correct file is returned."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
file2 = FileSystemFile(name="test_file_2", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
folder.add_file(file2)
assert folder.size == 20
assert len(folder.files) is 2
assert folder.get_file_by_id(file_id=file.uuid) is file
def test_folder_quarantine_state():
"""Tests the changing of folder quarantine status."""
folder = FileSystemFolder(name="test")
assert folder.quarantine_status() is False
folder.quarantine()
assert folder.quarantine_status() is True
folder.end_quarantine()
assert folder.quarantine_status() is False
def test_serialisation():
"""Test to check that the object serialisation works correctly."""
folder = FileSystemFolder(name="test")
file = FileSystemFile(name="test_file", size=10, file_type=FileSystemFileType.DOC)
folder.add_file(file)
serialised_folder = folder.model_dump_json()
deserialised_folder = FileSystemFolder.model_validate_json(serialised_folder)
assert folder.model_dump_json() == deserialised_folder.model_dump_json()

View File

@@ -1,5 +1,7 @@
import json
import pytest
from primaite.simulator.network.container import Network
@@ -10,6 +12,7 @@ def test_creating_container():
assert net.links == {}
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_describe_state():
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
net = Network()

View File

@@ -0,0 +1,20 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
def test_creation():
network = arcd_uc2_network()
client_1: Node = network.get_node_by_hostname("client_1")
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;"

View File

@@ -1,17 +1,18 @@
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.database import DatabaseService
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.software import SoftwareCriticality, SoftwareHealthState
import json
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.system.services.database_service import DatabaseService
def test_creation():
db = DatabaseService(
name="SQL-database",
health_state_actual=SoftwareHealthState.GOOD,
health_state_visible=SoftwareHealthState.GOOD,
criticality=SoftwareCriticality.MEDIUM,
ports=[
Port.SQL_SERVER,
],
operating_state=ServiceOperatingState.RUNNING,
)
@pytest.fixture(scope="function")
def database_server() -> Node:
node = Node(hostname="db_node")
node.software_manager.install(DatabaseService)
node.software_manager.software["DatabaseService"].start()
return node
def test_creation(database_server):
database_server.software_manager.show()

View File

@@ -0,0 +1,100 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns_client import DNSClient
from primaite.simulator.system.services.dns_server import DNSServer
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Node(hostname="dns_server")
node.software_manager.install(software_class=DNSServer)
node.software_manager.software["DNSServer"].start()
return node
@pytest.fixture(scope="function")
def dns_client() -> Node:
node = Node(hostname="dns_client")
node.software_manager.install(software_class=DNSClient)
node.software_manager.software["DNSClient"].start()
return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
# return none for an unknown domain
assert dns_server_service.dns_lookup("fake-domain.com") is None
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
# add a domain to the dns client cache
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
assert dns_client_service.check_domain_exists("fake-domain.com") is False
assert dns_client_service.check_domain_exists("real-domain.com") is True
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
is False
)
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
is True
)
dns_server_service.show()
def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service.receive(
payload=DNSPacket(
dns_request=DNSRequest(domain_name_request="real-domain.com"),
dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")),
)
)
# domain name should be saved to cache
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")