Merge dev into feature branch

This commit is contained in:
Marek Wolan
2024-08-12 09:02:11 +01:00
16 changed files with 459 additions and 95 deletions

View File

@@ -9,13 +9,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
- Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes.
- Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the
main port they're assigned.
### Changed
- Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
- File and folder observations can now be configured to always show the true health status, or require scanning like before.
### Fixed
- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config)
- Updated `SoftwareManager` `install` and `uninstall` to handle all functionality that was being done at the `install`
and `uninstall` methods in the `Node` class.
- Updated the `receive_payload_from_session_manager` method in `SoftwareManager` so that it now sends a copy of the
payload to any software listening on the destination port of the `Frame`.
### Removed
- Removed the `install` and `uninstall` methods in the `Node` class.
## [3.2.0] - 2024-07-18

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Final, Tuple
from report import build_benchmark_md_report
from report import build_benchmark_md_report, md2pdf
from stable_baselines3 import PPO
import primaite
@@ -159,6 +159,13 @@ def run(
learning_rate: float = 3e-4,
) -> None:
"""Run the PrimAITE benchmark."""
# generate report folder
v_str = f"v{primaite.__version__}"
version_result_dir = _RESULTS_ROOT / v_str
version_result_dir.mkdir(exist_ok=True, parents=True)
output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md"
benchmark_start_time = datetime.now()
session_metadata_dict = {}
@@ -193,6 +200,12 @@ def run(
session_metadata=session_metadata_dict,
config_path=data_manipulation_config_path(),
results_root_path=_RESULTS_ROOT,
output_path=output_path,
)
md2pdf(
md_path=output_path,
pdf_path=str(output_path).replace(".md", ".pdf"),
css_path="static/styles.css",
)

View File

@@ -2,6 +2,7 @@
import json
import sys
from datetime import datetime
from os import PathLike
from pathlib import Path
from typing import Dict, Optional
@@ -14,7 +15,7 @@ from utils import _get_system_info
import primaite
PLOT_CONFIG = {
"size": {"auto_size": False, "width": 1500, "height": 900},
"size": {"auto_size": False, "width": 800, "height": 640},
"template": "plotly_white",
"range_slider": False,
}
@@ -144,6 +145,20 @@ def _plot_benchmark_metadata(
yaxis={"title": "Total Reward"},
title=title,
)
fig.update_layout(
legend=dict(
yanchor="top",
y=0.99,
xanchor="left",
x=0.01,
bgcolor="rgba(255,255,255,0.3)",
)
)
for trace in fig["data"]:
if trace["name"].startswith("Session"):
trace["showlegend"] = False
fig["data"][0]["name"] = "Individual Sessions"
fig["data"][0]["showlegend"] = True
return fig
@@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
title=title,
)
fig["data"][0]["showlegend"] = True
fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h"))
return fig
@@ -248,14 +264,7 @@ def _plot_av_s_per_100_steps_10_nodes(
versions = sorted(list(version_times_dict.keys()))
times = [version_times_dict[version] for version in versions]
fig.add_trace(
go.Bar(
x=versions,
y=times,
text=times,
textposition="auto",
)
)
fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}"))
fig.update_layout(
xaxis_title="PrimAITE Version",
@@ -267,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes(
def build_benchmark_md_report(
benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path
benchmark_start_time: datetime,
session_metadata: Dict,
config_path: Path,
results_root_path: Path,
output_path: PathLike,
) -> None:
"""
Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs.
@@ -319,7 +332,7 @@ def build_benchmark_md_report(
data = benchmark_metadata_dict
primaite_version = data["primaite_version"]
with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file:
with open(output_path, "w") as file:
# Title
file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n")
file.write("## PrimAITE Dev Team\n")
@@ -393,3 +406,15 @@ def build_benchmark_md_report(
f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]"
f"({performance_benchmark_plot_path.name})\n"
)
def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None:
"""Generate PDF version of Markdown report."""
from md2pdf.core import md2pdf
md2pdf(
pdf_file_path=pdf_path,
md_file_path=md_path,
base_url=Path(md_path).parent,
css_file_path=css_path,
)

View File

@@ -0,0 +1,34 @@
body {
font-family: 'Arial', sans-serif;
line-height: 1.6;
/* margin: 1cm; */
}
h1, h2, h3, h4, h5, h6 {
font-weight: bold;
/* margin: 1em 0; */
}
p {
/* margin: 0.5em 0; */
}
ul, ol {
margin: 1em 0;
padding-left: 1.5em;
}
pre {
background: #f4f4f4;
padding: 0.5em;
overflow-x: auto;
}
img {
max-width: 100%;
height: auto;
}
table {
width: 100%;
border-collapse: collapse;
margin: 1em 0;
}
th, td {
padding: 0.5em;
border: 1px solid #ddd;
}

View File

@@ -25,3 +25,35 @@ The configuration options are the attributes that fall under the options for an
Optional. Default value is ``2``.
The number of timesteps the |SOFTWARE_NAME| will remain in a ``FIXING`` state before going into a ``GOOD`` state.
``listen_on_ports``
"""""""""""""""""""
The set of ports to listen on. This is in addition to the main port the software is designated. This set can either be
the string name of ports or the port integers
Example:
.. code-block:: yaml
simulation:
network:
nodes:
- hostname: client
type: computer
ip_address: 192.168.10.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
services:
- type: DatabaseService
options:
backup_server_ip: 10.10.1.12
listen_on_ports:
- 631
applications:
- type: WebBrowser
options:
target_url: http://sometech.ai
listen_on_ports:
- SMB

View File

@@ -75,7 +75,8 @@ dev = [
"wheel==0.38.4",
"nbsphinx==0.9.4",
"nbmake==1.5.4",
"pytest-xdist==3.3.1"
"pytest-xdist==3.3.1",
"md2pdf",
]
[project.scripts]

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import numpy as np
from pydantic import BaseModel, ConfigDict
@@ -44,8 +44,10 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import Software
_LOGGER = getLogger(__name__)
@@ -328,6 +330,20 @@ class PrimaiteGame:
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
def _set_software_listen_on_ports(software: Union[Software, Service], software_cfg: Dict):
"""Set listener ports on software."""
listen_on_ports = []
for port_id in set(software_cfg.get("options", {}).get("listen_on_ports", [])):
port = None
if isinstance(port_id, int):
port = Port(port_id)
elif isinstance(port_id, str):
port = Port[port_id]
if port:
listen_on_ports.append(port)
software.listen_on_ports = set(listen_on_ports)
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
@@ -341,6 +357,7 @@ class PrimaiteGame:
if "fix_duration" in service_cfg.get("options", {}):
new_service.fixing_duration = service_cfg["options"]["fix_duration"]
_set_software_listen_on_ports(new_service, service_cfg)
# start the service
new_service.start()
else:
@@ -390,6 +407,8 @@ class PrimaiteGame:
_LOGGER.error(msg)
raise ValueError(msg)
_set_software_listen_on_ports(new_application, application_cfg)
# run the application
new_application.run()

View File

@@ -80,14 +80,14 @@
"outputs": [],
"source": [
"# Login to the remote (node_b) from local (node_a)\n",
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")"
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"admin\", ip_address=\"192.168.0.11\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can view all active connections to a terminal through use of the `show()` method"
"You can view all active connections to a terminal through use of the `show()` method."
]
},
{
@@ -180,9 +180,24 @@
"term_a_term_b_remote_connection.disconnect()\n",
"\n",
"terminal_a.show()\n",
"\n",
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Disconnected Terminal sessions will no longer show in the node's Terminal connection list, but will be under the historic sessions in the `user_session_manager`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
]
}
],
"metadata": {

View File

@@ -1174,7 +1174,7 @@ class UserSessionManager(Service):
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
# todo add doc about request schemas
rm.add_request(
"remote_login",
RequestType(
@@ -1278,6 +1278,10 @@ class UserSessionManager(Service):
if self.local_session:
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
self._timeout_session(self.local_session)
for session in self.remote_sessions:
remote_session = self.remote_sessions[session]
if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep:
self._timeout_session(remote_session)
def _timeout_session(self, session: UserSession) -> None:
"""
@@ -1294,6 +1298,13 @@ class UserSessionManager(Service):
self.remote_sessions.pop(session.uuid)
session_type = "Remote"
session_identity = f"{session_identity} {session.remote_ip_address}"
self.parent.terminal._connections.pop(session.uuid)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "user_timeout", "connection_id": session.uuid},
dest_port=Port.SSH,
dest_ip_address=session.remote_ip_address,
)
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")

View File

@@ -1,4 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from copy import deepcopy
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
@@ -76,6 +77,8 @@ class SoftwareManager:
for software in self.port_protocol_mapping.values():
if software.operating_state in {ApplicationOperatingState.RUNNING, ServiceOperatingState.RUNNING}:
open_ports.append(software.port)
if software.listen_on_ports:
open_ports += list(software.listen_on_ports)
return open_ports
def check_port_is_open(self, port: Port, protocol: IPProtocol) -> bool:
@@ -223,7 +226,9 @@ class SoftwareManager:
frame: Frame,
):
"""
Receive a payload from the SessionManager and forward it to the corresponding service or application.
Receive a payload from the SessionManager and forward it to the corresponding service or applications.
This function handles both software assigned a specific port, and software listening in on other ports.
:param payload: The payload being received.
:param session: The transport session the payload originates from.
@@ -231,14 +236,25 @@ class SoftwareManager:
if payload.__class__.__name__ == "PortScanPayload":
self.software.get("NMAP").receive(payload=payload, session_id=session_id)
return
receiver: Optional[Union[Service, Application]] = self.port_protocol_mapping.get((port, protocol), None)
if receiver:
receiver.receive(
main_receiver = self.port_protocol_mapping.get((port, protocol), None)
if main_receiver:
main_receiver.receive(
payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame
)
else:
listening_receivers = [
software
for software in self.software.values()
if port in software.listen_on_ports and software != main_receiver
]
for receiver in listening_receivers:
receiver.receive(
payload=deepcopy(payload),
session_id=session_id,
from_network_interface=from_network_interface,
frame=frame,
)
if not main_receiver and not listening_receivers:
self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}")
pass
def show(self, markdown: bool = False):
"""

View File

@@ -385,6 +385,8 @@ class DatabaseService(Service):
)
else:
result = {"status_code": 401, "type": "sql"}
else:
self.sys_log.info(f"{self.name}: Ignoring payload as it is not a Database payload")
self.send(payload=result, session_id=session_id)
return True

View File

@@ -33,8 +33,8 @@ class TerminalClientConnection(BaseModel):
parent_terminal: Terminal
"""The parent Node that this connection was created on."""
session_id: str = None
"""Session ID that connection is linked to"""
ssh_session_id: str = None
"""Session ID that connection is linked to, used for sending commands via session manager."""
connection_uuid: str = None
"""Connection UUID"""
@@ -52,7 +52,7 @@ class TerminalClientConnection(BaseModel):
"""Flag to state whether the connection is active or not"""
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')"
return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')"
def __repr__(self) -> str:
return self.__str__()
@@ -124,13 +124,14 @@ class RemoteTerminalConnection(TerminalClientConnection):
ssh_command=command,
)
return self.parent_terminal.send(payload=payload, session_id=self.session_id)
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
class Terminal(Service):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
@@ -169,31 +170,50 @@ class Terminal(Service):
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._process_local_login(username=request[0], password=request[1])
if login:
return RequestResponse(status="success", data={})
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={})
return RequestResponse(status="failure", data={"reason": "Invalid login credentials"})
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
if login:
return RequestResponse(status="success", data={})
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={})
def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
connection_id: str = request[1]
self.execute(command, connection_id=connection_id)
return RequestResponse(status="success", data={})
ip_address: IPv4Address = IPv4Address(request[1])
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={},
)
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from connection."""
connection_uuid = request[0]
# TODO: Uncomment this when UserSessionManager merged.
# self.parent.UserSessionManager.logoff(connection_uuid)
self.parent.user_session_manager.local_logout(connection_uuid)
self._disconnect(connection_uuid)
return RequestResponse(status="success", data={})
rm.add_request(
@@ -208,21 +228,22 @@ class Terminal(Service):
rm.add_request(
"Execute",
request_type=RequestType(func=_execute_request),
request_type=RequestType(func=remote_execute_request),
)
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
return rm
def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]:
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
valid_connection = self._check_client_connection(connection_id=connection_id)
if valid_connection:
return self.parent.apply_request(command)
else:
self.sys_log.error("Invalid connection ID provided")
return None
return self.parent.apply_request(command)
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
"""Find Remote Terminal Connection from a given IP."""
for connection in self._connections.values():
if connection.ip_address == ip_address:
return connection
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
"""Create a new connection object and amend to list of active connections.
@@ -234,7 +255,7 @@ class Terminal(Service):
new_connection = LocalTerminalConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
session_id=session_id,
ssh_session_id=session_id,
time=datetime.now(),
)
self._connections[connection_uuid] = new_connection
@@ -257,7 +278,7 @@ class Terminal(Service):
:type: ip_address: Optional[IPv4Address]
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning("Cannot login as service is not running.")
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
@@ -277,23 +298,22 @@ class Terminal(Service):
:return: boolean, True if successful, else False
"""
# TODO: Un-comment this when UserSessionManager is merged.
# connection_uuid = self.parent.UserSessionManager.login(username=username, password=password)
connection_uuid = str(uuid4())
connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password)
if connection_uuid:
self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}")
self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}")
# Add new local session to list of connections and return
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
else:
self.sys_log.warning("Login failed, incorrect Username or Password")
self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password")
return None
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
return connection_id in self._client_connection_requests
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._connections else False
return connection_id in self._connections
def _send_remote_login(
self,
@@ -323,7 +343,9 @@ class Terminal(Service):
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}")
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
@@ -332,7 +354,7 @@ class Terminal(Service):
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
return remote_terminal_connection
else:
self.sys_log.warning(f"Connection request{connection_request_id} declined")
self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined")
return None
else:
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
@@ -382,7 +404,7 @@ class Terminal(Service):
"""
client_connection = RemoteTerminalConnection(
parent_terminal=self,
session_id=session_id,
ssh_session_id=session_id,
connection_uuid=connection_id,
ip_address=source_ip,
connection_request_id=connection_request_id,
@@ -399,20 +421,20 @@ class Terminal(Service):
:param session_id: The session id the payload relates to.
:return: True.
"""
source_ip = kwargs["from_network_interface"].ip_address
self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}")
source_ip = kwargs["frame"].ip.src_ip_address
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
# TODO: uncomment this as part of 2781
# connection_id = self.parent.UserSessionManager.login(username=username, password=password)
connection_id = str(uuid4())
username = payload.user_account.username
password = payload.user_account.password
connection_id = self.parent.user_session_manager.remote_login(
username=username, password=password, remote_ip_address=source_ip
)
if connection_id:
connection_request_id = payload.connection_request_uuid
username = payload.user_account.username
password = payload.user_account.password
print(f"Connection ID is: {connection_request_id}")
self.sys_log.info(f"Connection authorised, session_id: {session_id}")
self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}")
self._create_remote_connection(
connection_id=connection_id,
connection_request_id=connection_request_id,
@@ -443,7 +465,7 @@ class Terminal(Service):
payload=payload, dest_port=self.port, session_id=session_id
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.info("Login Successful")
self.sys_log.info(f"{self.name}: Login Successful")
self._create_remote_connection(
connection_id=payload.connection_uuid,
connection_request_id=payload.connection_request_uuid,
@@ -453,14 +475,16 @@ class Terminal(Service):
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
self.sys_log.info("Received command to execute")
self.sys_log.info(f"{self.name}: Received command to execute")
command = payload.ssh_command
valid_connection = self._check_client_connection(payload.connection_uuid)
self.sys_log.info(f"Connection uuid is {valid_connection}")
if valid_connection:
return self.execute(command, payload.connection_uuid)
self.execute(command)
return True
else:
self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.")
self.sys_log.error(
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
)
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":
@@ -469,19 +493,30 @@ class Terminal(Service):
if valid_id:
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
self._disconnect(payload["connection_id"])
self.parent.user_session_manager.remote_logout(remote_session_id=connection_id)
else:
self.sys_log.info("No Active connection held for received connection ID.")
self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.")
if payload["type"] == "user_timeout":
connection_id = payload["connection_id"]
valid_id = self._check_client_connection(connection_id)
if valid_id:
connection = self._connections.pop(connection_id)
connection.is_active = False
self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.")
else:
self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.")
return True
def _disconnect(self, connection_uuid: str) -> bool:
"""Disconnect from the remote.
"""Disconnect connection.
:param connection_uuid: Connection ID that we want to disconnect.
:return True if successful, False otherwise.
"""
if not self._connections:
self.sys_log.warning("No remote connection present")
self.sys_log.warning(f"{self.name}: No remote connection present")
return False
connection = self._connections.pop(connection_uuid)
@@ -489,7 +524,7 @@ class Terminal(Service):
if isinstance(connection, RemoteTerminalConnection):
# Send disconnect command via software manager
session_id = connection.session_id
session_id = connection.ssh_session_id
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
@@ -501,7 +536,7 @@ class Terminal(Service):
return True
elif isinstance(connection, LocalTerminalConnection):
# No further action needed
self.parent.user_session_manager.local_logout()
return True
def send(
@@ -514,10 +549,10 @@ class Terminal(Service):
:param dest_up_address: The IP address of the payload destination.
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!")
self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!")
return False
self.sys_log.debug(f"Sending payload: {payload}")
self.sys_log.debug(f"{self.name}: Sending payload: {payload}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
)

View File

@@ -4,9 +4,10 @@ from abc import abstractmethod
from datetime import datetime
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import Field
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
@@ -252,6 +253,8 @@ class IOSoftware(Software):
"Indicates if the software uses UDP protocol for communication. Default is True."
port: Port
"The port to which the software is connected."
listen_on_ports: Set[Port] = Field(default_factory=set)
"The set of ports to listen on."
protocol: IPProtocol
"The IP Protocol the Software operates on."
_connections: Dict[str, Dict] = {}

View File

@@ -0,0 +1,39 @@
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
agent_log_level: INFO
save_agent_logs: true
write_agent_log_to_terminal: True
game:
max_episode_length: 256
ports:
- ARP
protocols:
- ICMP
- UDP
simulation:
network:
nodes:
- hostname: client
type: computer
ip_address: 192.168.10.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
services:
- type: DatabaseService
options:
backup_server_ip: 10.10.1.12
listen_on_ports:
- 631
applications:
- type: WebBrowser
options:
target_url: http://sometech.ai
listen_on_ports:
- SMB

View File

@@ -0,0 +1,84 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Any, Dict, List, Set
import yaml
from pydantic import Field
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.service import Service
from tests import TEST_ASSETS_ROOT
class _DatabaseListener(Service):
name: str = "DatabaseListener"
protocol: IPProtocol = IPProtocol.TCP
port: Port = Port.NONE
listen_on_ports: Set[Port] = {Port.POSTGRES_SERVER}
payloads_received: List[Any] = Field(default_factory=list)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
self.payloads_received.append(payload)
self.sys_log.info(f"{self.name}: received payload {payload}")
return True
def describe_state(self) -> Dict:
return super().describe_state()
def test_http_listener(client_server):
computer, server = client_server
server.software_manager.install(DatabaseService)
server_db = server.software_manager.software["DatabaseService"]
server_db.start()
server.software_manager.install(_DatabaseListener)
server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"]
server_db_listener.start()
computer.software_manager.install(DatabaseClient)
computer_db_client: DatabaseClient = computer.software_manager.software["DatabaseClient"]
computer_db_client.run()
computer_db_client.server_ip_address = server.network_interface[1].ip_address
assert len(server_db_listener.payloads_received) == 0
computer.session_manager.receive_payload_from_software_manager(
payload="masquerade as Database traffic",
dst_ip_address=server.network_interface[1].ip_address,
dst_port=Port.POSTGRES_SERVER,
ip_protocol=IPProtocol.TCP,
)
assert len(server_db_listener.payloads_received) == 1
db_connection = computer_db_client.get_new_connection()
assert db_connection
assert len(server_db_listener.payloads_received) == 2
assert db_connection.query("SELECT")
assert len(server_db_listener.payloads_received) == 3
def test_set_listen_on_ports_from_config():
config_path = TEST_ASSETS_ROOT / "configs" / "basic_node_with_software_listening_ports.yaml"
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
client: Computer = network.get_node_by_hostname("client")
assert Port.SMB in client.software_manager.get_open_ports()
assert Port.IPP in client.software_manager.get_open_ports()
web_browser = client.software_manager.software["WebBrowser"]
assert not web_browser.listen_on_ports.difference({Port.SMB, Port.IPP})

View File

@@ -185,7 +185,7 @@ def test_terminal_receive(basic_network):
)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
username="admin", password="admin", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["file_system", "create", "folder", folder_name])
@@ -208,7 +208,7 @@ def test_terminal_install(basic_network):
)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
username="admin", password="admin", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"])
@@ -225,9 +225,7 @@ def test_terminal_fail_when_closed(basic_network):
terminal.operating_state = ServiceOperatingState.STOPPED
assert not terminal.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
)
assert not terminal.login(username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address)
def test_terminal_disconnect(basic_network):
@@ -241,7 +239,7 @@ def test_terminal_disconnect(basic_network):
assert len(terminal_b._connections) == 0
term_a_on_term_b = terminal_a.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
username="admin", password="admin", ip_address=computer_b.network_interface[1].ip_address
)
assert len(terminal_b._connections) == 1
@@ -250,6 +248,8 @@ def test_terminal_disconnect(basic_network):
assert len(terminal_b._connections) == 0
assert term_a_on_term_b.is_active is False
def test_terminal_ignores_when_off(basic_network):
"""Terminal should ignore commands when not running"""
@@ -260,7 +260,7 @@ def test_terminal_ignores_when_off(basic_network):
computer_b: Computer = network.get_node_by_hostname("node_b")
term_a_on_term_b: RemoteTerminalConnection = terminal_a.login(
username="admin", password="Admin123!", ip_address="192.168.0.11"
username="admin", password="admin", ip_address="192.168.0.11"
) # login to computer_b
terminal_a.operating_state = ServiceOperatingState.STOPPED
@@ -276,7 +276,7 @@ def test_computer_remote_login_to_router(wireless_wan_network):
assert len(pc_a_terminal._connections) == 0
pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1")
pc_a_on_router_1 = pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.1.1")
assert len(pc_a_terminal._connections) == 1
@@ -295,7 +295,7 @@ def test_router_remote_login_to_computer(wireless_wan_network):
assert len(router_1_terminal._connections) == 0
router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2")
router_1_on_pc_a = router_1_terminal.login(username="admin", password="admin", ip_address="192.168.0.2")
assert len(router_1_terminal._connections) == 1
@@ -317,7 +317,7 @@ def test_router_blocks_SSH_traffic(wireless_wan_network):
assert len(pc_a_terminal._connections) == 0
pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2")
pc_a_terminal.login(username="admin", password="admin", ip_address="192.168.0.2")
assert len(pc_a_terminal._connections) == 0
@@ -333,7 +333,7 @@ def test_SSH_across_network(wireless_wan_network):
assert len(terminal_a._connections) == 0
terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2")
terminal_b_on_terminal_a = terminal_b.login(username="admin", password="admin", ip_address="192.168.0.2")
assert len(terminal_a._connections) == 1
@@ -347,11 +347,13 @@ def test_multiple_remote_terminals_same_node(basic_network):
assert len(terminal_a._connections) == 0
# Spam login requests to terminal.
for attempt in range(10):
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
# Spam login requests to node.
for attempt in range(3):
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 10
terminal_a.show()
assert len(terminal_a._connections) == 3
def test_terminal_rejects_commands_if_disconnect(basic_network):
@@ -363,7 +365,7 @@ def test_terminal_rejects_commands_if_disconnect(basic_network):
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 1
assert len(terminal_b._connections) == 1
@@ -378,3 +380,27 @@ def test_terminal_rejects_commands_if_disconnect(basic_network):
assert not computer_b.software_manager.software.get("RansomwareScript")
assert remote_connection.is_active is False
def test_terminal_connection_timeout(basic_network):
"""Test that terminal_connections are affected by UserSession timeout."""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 1
assert len(terminal_b._connections) == 1
assert len(computer_b.user_session_manager.remote_sessions) == 1
remote_session = computer_b.user_session_manager.remote_sessions[remote_connection.connection_uuid]
computer_b.user_session_manager._timeout_session(remote_session)
assert len(terminal_a._connections) == 0
assert len(terminal_b._connections) == 0
assert len(computer_b.user_session_manager.remote_sessions) == 0
assert not remote_connection.is_active