#2706 - Initial refactor of Terminal Class following review discussion on Friday. Terminal will now return a TerminalConnection/RemoteTerminalConnection object on login.

The new connection object can then be used to pass commands to the target node, without needing to form a full payload item.
This commit is contained in:
Charlie Crane
2024-08-05 09:29:17 +01:00
parent d9faa1a5da
commit 4bddf72cd3
4 changed files with 205 additions and 241 deletions

View File

@@ -26,7 +26,8 @@
"source": [
"from primaite.simulator.system.services.terminal.terminal import Terminal\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer"
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript"
]
},
{
@@ -83,7 +84,38 @@
"outputs": [],
"source": [
"# Login to the remote (node_b) from local (node_a)\n",
"terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)"
"from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n",
"\n",
"\n",
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(type(term_a_term_b_remote_connection))\n",
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
@@ -109,45 +141,6 @@
"The Terminal can be used to send requests to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage\n",
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n",
"\n",
"computer_b.software_manager.show()\n",
"\n",
"payload: SSHPacket = SSHPacket(\n",
" payload=[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n",
" transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n",
" connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n",
" sender_ip_address=computer_a.network_interface[1].ip_address,\n",
" target_ip_address=computer_b.network_interface[1].ip_address,\n",
")\n",
"\n",
"# Send command to install RansomwareScript\n",
"terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `RansomwareScript` can then be seen in the list of applications on the `node_b Software Manager`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -157,27 +150,6 @@
"Here, we send a command to `computer_b` to create a new folder titled \"Downloads\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.file_system.show()\n",
"\n",
"payload: SSHPacket = SSHPacket(\n",
" payload=[\"file_system\", \"create\", \"folder\", \"Downloads\"],\n",
" transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n",
" connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n",
" sender_ip_address=computer_a.network_interface[1].ip_address,\n",
" target_ip_address=computer_b.network_interface[1].ip_address,\n",
")\n",
"\n",
"terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)\n",
"\n",
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@@ -76,6 +76,8 @@ class SSHPacket(DataPacket):
user_account: Optional[SSHUserCredentials] = None
"""User Account Credentials if passed"""
connection_request_uuid: Optional[str] = None # Connection Request uuid.
connection_uuid: Optional[str] = None # The connection uuid used to validate the session
ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse

View File

@@ -10,18 +10,11 @@ from pydantic import BaseModel
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.protocols.ssh import SSHPacket
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
class TerminalClientConnection(BaseModel):
@@ -31,40 +24,45 @@ class TerminalClientConnection(BaseModel):
This class is used to record current User Connections to the Terminal class.
"""
parent_node: Node # Technically should be HostNode but this causes circular import error.
parent_terminal: Terminal
"""The parent Node that this connection was created on."""
dest_ip_address: IPv4Address = None
"""Destination IP address of connection"""
session_id: str = None
"""Session ID that connection is linked to"""
_connection_uuid: str = None
connection_uuid: str = None
"""Connection UUID"""
@property
def client(self) -> Optional[Terminal]:
"""The Terminal that holds this connection."""
return self.parent_node.software_manager.software.get("Terminal")
return self.parent_terminal
def disconnect(self):
"""Disconnect the connection."""
if self.client:
self.client._disconnect(self._connection_uuid) # noqa
def disconnect(self) -> bool:
"""Disconnect the session."""
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
class RemoteTerminalConnection(TerminalClientConnection):
"""
RemoteTerminalConnection Class.
This class acts as broker between the terminal and remote.
"""
def execute(self, command: Any) -> bool:
"""Execute a given command on the remote Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
# Send command to remote terminal to process.
return self.parent_terminal.send(payload=command, session_id=self.session_id)
class Terminal(Service):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING
"Initial Operating State"
health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD
"Service Health State"
_connections: Dict[str, TerminalClientConnection] = {}
"List of active connections held on this terminal."
_client_connection_requests: Dict[str, Optional[str]] = {}
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
@@ -155,34 +153,40 @@ class Terminal(Service):
return rm
def _add_new_connection(self, connection_uuid: str, session_id: str):
def execute(self, command: List[Any]) -> RequestResponse:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
def _create_local_connection(self, connection_uuid: str, session_id: str) -> RemoteTerminalConnection:
"""Create a new connection object and amend to list of active connections."""
self._connections[connection_uuid] = TerminalClientConnection(
parent_node=self.software_manager.node,
new_connection = TerminalClientConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
session_id=session_id,
)
self._connections[connection_uuid] = new_connection
self._client_connection_requests[connection_uuid] = new_connection
def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool:
"""Process User request to login to Terminal.
return new_connection
If ip_address is passed, login will attempt a remote login to the node at that address.
:param username: The username credential.
:param password: The user password component of credentials.
:param dest_ip_address: The IP address of the node we want to connect to.
:return: True if successful, False otherwise.
"""
def login(
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
) -> Optional[TerminalClientConnection]:
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local."""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning("Cannot process login as service is not running")
return False
self.sys_log.warning("Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
if ip_address:
# if ip_address has been provided, we assume we are logging in to a remote terminal.
return self._send_remote_login(username=username, password=password, ip_address=ip_address)
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
)
else:
return self._process_local_login(username=username, password=password)
return self._process_local_login(username=username, password=password)
def _process_local_login(self, username: str, password: str) -> bool:
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
"""Local session login to terminal.
:param username: Username for login.
@@ -195,110 +199,114 @@ class Terminal(Service):
if connection_uuid:
self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}")
# Add new local session to list of connections
session_id = str(uuid4())
self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id)
return True
self._create_local_connection(connection_uuid=connection_uuid, session_id="")
return TerminalClientConnection(parent_terminal=self, session_id="", connection_uuid=connection_uuid)
else:
self.sys_log.warning("Login failed, incorrect Username or Password")
return False
return None
def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool:
"""Attempt to login to a remote terminal.
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
:param username: username for login.
:param password: password for login.
:ip_address: IP address of the target node for login.
"""
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
def _send_remote_login(
self,
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: str,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Process a remote login attempt."""
self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}")
if is_reattempt:
valid_connection = self._check_client_connection(connection_id=connection_request_id)
if valid_connection:
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
return remote_terminal_connection
else:
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
user_account=user_account,
payload = {
"type": "login_request",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=ip_address, dest_port=self.port
)
return self._send_remote_login(
username=username,
password=password,
ip_address=ip_address,
is_reattempt=True,
connection_request_id=connection_request_id,
)
self.sys_log.info(f"Sending remote login request to {ip_address}")
return self.send(payload=payload, dest_ip_address=ip_address)
def _create_remote_connection(self, connection_id: str, connection_request_id: str, session_id: str) -> None:
"""Create a new TerminalClientConnection Object."""
client_connection = RemoteTerminalConnection(
parent_terminal=self, session_id=session_id, connection_uuid=connection_id
)
self._connections[connection_id] = client_connection
self._client_connection_requests[connection_request_id] = client_connection
def _process_remote_login(self, payload: SSHPacket, session_id: str) -> bool:
"""Processes a remote terminal requesting to login to this terminal.
:param payload: The SSH Payload Packet.
:param session_id: Session ID for sending login response.
:return: True if successful, else False.
def receive(self, session_id: str, payload: Any, **kwargs) -> bool:
"""
username: str = payload.user_account.username
password: str = payload.user_account.password
self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}")
# TODO: Un-comment this when UserSessionManager is merged.
# connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password)
connection_uuid = str(uuid4())
if connection_uuid:
# Send uuid to remote
self.sys_log.info(
f"Remote login authorised, connection ID {connection_uuid} for " f"{username} in session {session_id}"
)
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
return_payload = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_uuid=connection_uuid,
)
self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id)
Receive a payload from the Software Manager.
self.send(payload=return_payload, session_id=session_id)
return True
else:
# UserSessionManager has returned None
self.sys_log.warning("Login failed, incorrect Username or Password")
return False
def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool:
"""Receive Payload and process for a response.
:param payload: The message contents received.
:param session_id: Session ID of received message.
:return: True if successful, else False.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
self.sys_log.debug(f"Received payload: {payload}")
self.sys_log.info(f"Received payload: {payload}")
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "login_request":
# add connection
connection_request_id = payload["connection_request_id"]
username = payload["username"]
password = payload["password"]
print(f"Connection ID is: {connection_request_id}")
self.sys_log.info(f"Connection authorised, session_id: {session_id}")
self._create_remote_connection(
connection_id=connection_request_id,
connection_request_id=payload["connection_request_id"],
session_id=session_id,
)
payload = {
"type": "login_success",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
elif payload["type"] == "login_success":
self.sys_log.info(f"Login was successful! session_id is:{session_id}")
connection_request_id = payload["connection_request_id"]
self._create_remote_connection(
connection_id=connection_request_id,
session_id=session_id,
connection_request_id=connection_request_id,
)
if not isinstance(payload, SSHPacket):
return False
elif payload["type"] == "disconnect":
connection_id = payload["connection_id"]
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server")
self._disconnect(payload["connection_id"])
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning("Cannot process message as not running")
return False
if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE:
# Close the channel
connection_id = kwargs["connection_id"]
dest_ip_address = kwargs["dest_ip_address"]
self.disconnect(dest_ip_address=dest_ip_address)
self.sys_log.debug(f"Disconnecting {connection_id}")
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
return self._process_remote_login(payload=payload, session_id=session_id)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}")
return True
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
return self.execute(command=payload.payload)
else:
self.sys_log.warning("Encounter unexpected message type, rejecting connection")
return False
if isinstance(payload, list):
# A request? For me?
self.execute(payload)
return True
def execute(self, command: List[Any]) -> RequestResponse:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
def _disconnect(self, connection_uuid: str) -> bool:
"""Disconnect from the remote.
@@ -309,30 +317,16 @@ class Terminal(Service):
self.sys_log.warning("No remote connection present")
return False
dest_ip_address = self._connections[connection_uuid].dest_ip_address
session_id = self._connections[connection_uuid].session_id
self._connections.pop(connection_uuid)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid},
dest_ip_address=dest_ip_address,
dest_port=self.port,
payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
def disconnect(self, connection_uuid: Optional[str]) -> bool:
"""Disconnect the terminal.
If no connection id has been supplied, disconnects the first connection.
:param connection_uuid: Connection ID that we want to disconnect.
:return: True if successful, False otherwise.
"""
if not connection_uuid:
connection_uuid = next(iter(self._connections))
return self._disconnect(connection_uuid=connection_uuid)
def send(
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
) -> bool:
@@ -345,6 +339,7 @@ class Terminal(Service):
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!")
return False
self.sys_log.debug(f"Sending payload: {payload}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id

View File

@@ -16,7 +16,7 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
@@ -87,8 +87,6 @@ def test_terminal_send(basic_network):
payload="Test_Payload",
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
sender_ip_address=computer_a.network_interface[1].ip_address,
target_ip_address=computer_b.network_interface[1].ip_address,
)
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
@@ -106,11 +104,13 @@ def test_terminal_receive(basic_network):
payload=["file_system", "create", "folder", folder_name],
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
sender_ip_address=computer_a.network_interface[1].ip_address,
target_ip_address=computer_b.network_interface[1].ip_address,
)
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["file_system", "create", "folder", folder_name])
# Assert that the Folder has been correctly created
assert computer_b.file_system.get_folder(folder_name)
@@ -127,11 +127,13 @@ def test_terminal_install(basic_network):
payload=["software_manager", "application", "install", "RansomwareScript"],
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
sender_ip_address=computer_a.network_interface[1].ip_address,
target_ip_address=computer_b.network_interface[1].ip_address,
)
terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"])
assert computer_b.software_manager.software.get("RansomwareScript")
@@ -145,29 +147,30 @@ def test_terminal_fail_when_closed(basic_network):
terminal.operating_state = ServiceOperatingState.STOPPED
assert (
terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
is False
assert not terminal.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
)
def test_terminal_disconnect(basic_network):
"""Terminal should set is_connected to false on disconnect"""
"""Test Terminal disconnects"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
assert terminal_a.is_connected is False
assert len(terminal_b._connections) == 0
terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address)
term_a_on_term_b = terminal_a.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
)
assert terminal_a.is_connected is True
assert len(terminal_b._connections) == 1
terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address)
term_a_on_term_b.disconnect()
assert terminal_a.is_connected is False
assert len(terminal_b._connections) == 0
def test_terminal_ignores_when_off(basic_network):
@@ -178,21 +181,13 @@ def test_terminal_ignores_when_off(basic_network):
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") # login to computer_b
assert terminal_a.is_connected is True
term_a_on_term_b: RemoteTerminalConnection = terminal_a.login(
username="admin", password="Admin123!", ip_address="192.168.0.11"
) # login to computer_b
terminal_a.operating_state = ServiceOperatingState.STOPPED
payload: SSHPacket = SSHPacket(
payload="Test_Payload",
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
sender_ip_address=computer_a.network_interface[1].ip_address,
target_ip_address="192.168.0.11",
)
assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11")
assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"])
def test_network_simulation(basic_network):