Merge remote-tracking branch 'origin/dev' into feature/2769-implement-user-account-action-space

This commit is contained in:
Marek Wolan
2024-08-09 11:12:40 +01:00
21 changed files with 1632 additions and 34 deletions

View File

@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -44,6 +44,7 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -57,6 +58,7 @@ SERVICE_TYPES_MAPPING = {
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
"Terminal": Terminal,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
@@ -70,6 +72,8 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
seed: int = None
"""Random number seed for RNGs."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]

View File

@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Terminal Processing\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n",
"\n",
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.terminal.terminal import Terminal\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n",
"from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n",
"\n",
"def basic_network() -> Network:\n",
" \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n",
" network = Network()\n",
" node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_a.power_on()\n",
" node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_b.power_on()\n",
" network.connect(node_a.network_interface[1], node_b.network_interface[1])\n",
" return network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The terminal can be accessed from a `Node` via the `software_manager` as demonstrated below. \n",
"\n",
"In the example, we have a basic network consisting of two computers, connected to form a basic network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"network: Network = basic_network()\n",
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are remotely logging in to the 'admin' account on `node_b`, from `node_a`. \n",
"If you are not logged in, any commands sent will be rejected by the remote.\n",
"\n",
"Remote Logins return a RemoteTerminalConnection object, which can be used for sending commands to the remote node. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Login to the remote (node_b) from local (node_a)\n",
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"admin\", ip_address=\"192.168.0.11\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can view all active connections to a terminal through use of the `show()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The new connection object allows us to forward commands to be executed on the target node. The example below demonstrates how you can remotely install an application on the target node."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to create a downloads folder. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display the current state of the file system on computer_b\n",
"computer_b.file_system.show()\n",
"\n",
"# Send command\n",
"term_a_term_b_remote_connection.execute([\"file_system\", \"create\", \"folder\", \"downloads\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resultant call to `computer_b.file_system.show()` shows that the new folder has been created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When finished, the connection can be closed by calling the `disconnect` function of the Remote Client object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display active connection\n",
"terminal_a.show()\n",
"terminal_b.show()\n",
"\n",
"term_a_term_b_remote_connection.disconnect()\n",
"\n",
"terminal_a.show()\n",
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Disconnected Terminal sessions will no longer show in the node's Terminal connection list, but will be under the historic sessions in the `user_session_manager`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import random
import sys
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
# Check torch is installed
try:
import torch as th
except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
"""
Set random number generators.
:param seed: int
"""
if seed is None or seed == -1:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# Seed the RNG for all devices (both CPU and CUDA)
# if torch not installed don't set random seed.
if sys.modules["torch"]:
th.manual_seed(seed)
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
return seed
class PrimaiteGymEnv(gymnasium.Env):
"""
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}

View File

@@ -30,6 +30,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.software import IOSoftware, Software
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
@@ -1173,7 +1174,7 @@ class UserSessionManager(Service):
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
# todo add doc about request schemas
rm.add_request(
"remote_login",
RequestType(
@@ -1277,6 +1278,10 @@ class UserSessionManager(Service):
if self.local_session:
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
self._timeout_session(self.local_session)
for session in self.remote_sessions:
remote_session = self.remote_sessions[session]
if remote_session.last_active_step + self.remote_session_timeout_steps <= timestep:
self._timeout_session(remote_session)
def _timeout_session(self, session: UserSession) -> None:
"""
@@ -1293,6 +1298,13 @@ class UserSessionManager(Service):
self.remote_sessions.pop(session.uuid)
session_type = "Remote"
session_identity = f"{session_identity} {session.remote_ip_address}"
self.parent.terminal._connections.pop(session.uuid)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "user_timeout", "connection_id": session.uuid},
dest_port=Port.SSH,
dest_ip_address=session.remote_ip_address,
)
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")
@@ -1541,6 +1553,11 @@ class Node(SimComponent):
"""The Nodes User Session Manager."""
return self.software_manager.software.get("UserSessionManager") # noqa
@property
def terminal(self) -> Optional[Terminal]:
"""The Nodes Terminal."""
return self.software_manager.software.get("Terminal")
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Attempt to log in to the node uas a local user.

View File

@@ -21,6 +21,7 @@ from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
_LOGGER = getLogger(__name__)
@@ -298,6 +299,7 @@ class HostNode(Node):
* DNS (Domain Name System) Client: Resolves domain names to IP addresses.
* FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers.
* NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers.
* Terminal Client: Handles SSH requests between HostNode and external components.
Applications:
------------
@@ -314,6 +316,7 @@ class HostNode(Node):
"NMAP": NMAP,
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
"""List of system software that is automatically installed on nodes."""

View File

@@ -24,6 +24,7 @@ from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
@@ -1203,6 +1204,7 @@ class Router(NetworkNode):
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
num_ports: int

View File

@@ -0,0 +1,89 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from typing import Optional
from primaite.interface.request import RequestResponse
from primaite.simulator.network.protocols.packet import DataPacket
class SSHTransportMessage(IntEnum):
"""
Enum list of Transport layer messages that can be handled by the simulation.
Each msg value is equivalent to the real-world.
"""
SSH_MSG_USERAUTH_REQUEST = 50
"""Requests User Authentication."""
SSH_MSG_USERAUTH_FAILURE = 51
"""Indicates User Authentication failed."""
SSH_MSG_USERAUTH_SUCCESS = 52
"""Indicates User Authentication was successful."""
SSH_MSG_SERVICE_REQUEST = 24
"""Requests a service - such as executing a command."""
# These two msgs are invented for primAITE however are modelled on reality
SSH_MSG_SERVICE_FAILED = 25
"""Indicates that the requested service failed."""
SSH_MSG_SERVICE_SUCCESS = 26
"""Indicates that the requested service was successful."""
class SSHConnectionMessage(IntEnum):
"""Int Enum list of all SSH's connection protocol messages that can be handled by the simulation."""
SSH_MSG_CHANNEL_OPEN = 80
"""Requests an open channel - Used in combination with SSH_MSG_USERAUTH_REQUEST."""
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 81
"""Confirms an open channel."""
SSH_MSG_CHANNEL_OPEN_FAILED = 82
"""Indicates that channel opening failure."""
SSH_MSG_CHANNEL_DATA = 84
"""Indicates that data is being sent through the channel."""
SSH_MSG_CHANNEL_CLOSE = 87
"""Closes the channel."""
class SSHUserCredentials(DataPacket):
"""Hold Username and Password in SSH Packets."""
username: str
"""Username for login"""
password: str
"""Password for login"""
class SSHPacket(DataPacket):
"""Represents an SSHPacket."""
transport_message: SSHTransportMessage
"""Message Transport Type"""
connection_message: SSHConnectionMessage
"""Message Connection Status"""
user_account: Optional[SSHUserCredentials] = None
"""User Account Credentials if passed"""
connection_request_uuid: Optional[str] = None
"""Connection Request UUID used when establishing a remote connection"""
connection_uuid: Optional[str] = None
"""Connection UUID used when validating a remote connection"""
ssh_output: Optional[RequestResponse] = None
"""RequestResponse from Request Manager"""
ssh_command: Optional[list] = None
"""Request String"""

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
if self.client and self.is_active:
self.client._disconnect(self.connection_id) # noqa
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
def __repr__(self) -> str:
return str(self)
class DatabaseClient(Application, identifier="DatabaseClient"):
"""
@@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""Connection ID to the Database Server."""
client_connections: Dict[str, DatabaseClientConnection] = {}
"""Keep track of active connections to Database Server."""
_client_connection_requests: Dict[str, Optional[str]] = {}
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
"""Dictionary of connection requests to Database Server."""
connected: bool = False
"""Boolean Value for whether connected to DB Server."""
@@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
return False
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
def _check_client_connection(self, connection_id: str) -> bool:
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
@@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
valid_connection = self._check_client_connection(connection_id=connection_request_id)
if valid_connection:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
database_client_connection = self._client_connection_requests.pop(connection_request_id)
self.sys_log.info(
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
f"Connection Request ID was {connection_request_id}."
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
if isinstance(database_client_connection, DatabaseClientConnection):
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
f"Using connection id {database_client_connection}"
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
self._last_connection_successful = False
return None
else:
self.sys_log.warning(
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
f"Connection Request ID was {connection_request_id}."
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
f"due to unknown client-side connection request id"
)
self._last_connection_successful = False
return None
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
@@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
self.sys_log.info(
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
)
return self._connect(
server_ip_address=self.server_ip_address,
password=self.server_password,

View File

@@ -191,12 +191,16 @@ class DatabaseService(Service):
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
f"capacity."
)
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
@@ -208,12 +212,16 @@ class DatabaseService(Service):
# try to create connection
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
f"returning status code 500"
)
else:
status_code = 401 # Unauthorised
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
f"(incorrect password), returning status code 401"
)
else:
status_code = 404 # service not found
return {

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,558 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
This class is used to record current User Connections to the Terminal class.
"""
parent_terminal: Terminal
"""The parent Node that this connection was created on."""
ssh_session_id: str = None
"""Session ID that connection is linked to, used for sending commands via session manager."""
connection_uuid: str = None
"""Connection UUID"""
connection_request_id: str = None
"""Connection request ID"""
time: datetime = None
"""Timestamp connection was created."""
ip_address: IPv4Address
"""Source IP of Connection"""
is_active: bool = True
"""Flag to state whether the connection is active or not"""
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id: '{self.connection_uuid}, ip_address: {self.ip_address}')"
def __repr__(self) -> str:
return self.__str__()
def __getitem__(self, key: Any) -> Any:
return getattr(self, key)
@property
def client(self) -> Optional[Terminal]:
"""The Terminal that holds this connection."""
return self.parent_terminal
def disconnect(self) -> bool:
"""Disconnect the session."""
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
@abstractmethod
def execute(self, command: Any) -> bool:
"""Execute a given command."""
pass
class LocalTerminalConnection(TerminalClientConnection):
"""
LocalTerminalConnectionClass.
This class represents a local terminal when connected.
"""
ip_address: str = "Local Connection"
def execute(self, command: Any) -> Optional[RequestResponse]:
"""Execute a given command on local Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return None
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return None
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
class RemoteTerminalConnection(TerminalClientConnection):
"""
RemoteTerminalConnection Class.
This class acts as broker between the terminal and remote.
"""
def execute(self, command: Any) -> bool:
"""Execute a given command on the remote Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return False
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return False
# Send command to remote terminal to process.
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=self.connection_request_id,
connection_uuid=self.connection_uuid,
ssh_command=command,
)
return self.parent_terminal.send(payload=payload, session_id=self.ssh_session_id)
class Terminal(Service):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
"""Dictionary of connect requests made to remote nodes."""
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
kwargs["port"] = Port.SSH
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
def show(self, markdown: bool = False):
"""
Display the remote connections to this terminal instance in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
self.show_connections(markdown=markdown)
def _init_request_manager(self) -> RequestManager:
"""Initialise Request manager."""
rm = super()._init_request_manager()
rm.add_request(
"send",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
)
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._process_local_login(username=request[0], password=request[1])
if login:
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={"reason": "Invalid login credentials"})
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
if login:
return RequestResponse(
status="success",
data={
"ip_address": login.ip_address,
},
)
else:
return RequestResponse(status="failure", data={})
def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
ip_address: IPv4Address = IPv4Address(request[1])
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
if remote_connection:
outcome = remote_connection.execute(command)
if outcome:
return RequestResponse(
status="success",
data={},
)
else:
return RequestResponse(
status="failure",
data={},
)
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from connection."""
connection_uuid = request[0]
self.parent.user_session_manager.local_logout(connection_uuid)
self._disconnect(connection_uuid)
return RequestResponse(status="success", data={})
rm.add_request(
"Login",
request_type=RequestType(func=_login),
)
rm.add_request(
"Remote Login",
request_type=RequestType(func=_remote_login),
)
rm.add_request(
"Execute",
request_type=RequestType(func=remote_execute_request),
)
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
return rm
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
return self.parent.apply_request(command)
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
"""Find Remote Terminal Connection from a given IP."""
for connection in self._connections.values():
if connection.ip_address == ip_address:
return connection
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
"""Create a new connection object and amend to list of active connections.
:param connection_uuid: Connection ID of the new local connection
:param session_id: Session ID of the new local connection
:return: TerminalClientConnection object
"""
new_connection = LocalTerminalConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
ssh_session_id=session_id,
time=datetime.now(),
)
self._connections[connection_uuid] = new_connection
self._client_connection_requests[connection_uuid] = new_connection
return new_connection
def login(
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
) -> Optional[TerminalClientConnection]:
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt. If None, login is assumed local.
:type: ip_address: Optional[IPv4Address]
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
if ip_address:
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
)
else:
return self._process_local_login(username=username, password=password)
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
"""Local session login to terminal.
:param username: Username for login.
:param password: Password for login.
:return: boolean, True if successful, else False
"""
# TODO: Un-comment this when UserSessionManager is merged.
connection_uuid = self.parent.user_session_manager.local_login(username=username, password=password)
if connection_uuid:
self.sys_log.info(f"{self.name}: Login request authorised, connection uuid: {connection_uuid}")
# Add new local session to list of connections and return
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
else:
self.sys_log.warning(f"{self.name}: Login failed, incorrect Username or Password")
return None
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return connection_id in self._client_connection_requests
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return connection_id in self._connections
def _send_remote_login(
self,
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: str,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Send a remote login attempt and connect to Node.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt.
:type: ip_address: IPv4Address
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: is_reattempt: True if the request has been reattempted. Default False.
:type: is_reattempt: Optional[bool]
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
self.sys_log.info(
f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}"
)
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
if isinstance(remote_terminal_connection, RemoteTerminalConnection):
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
return remote_terminal_connection
else:
self.sys_log.warning(f"{self.name}: Connection request {connection_request_id} declined")
return None
else:
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
payload_contents = {
"type": "login_request",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
user_account=user_details,
connection_request_uuid=connection_request_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=ip_address, dest_port=self.port
)
return self._send_remote_login(
username=username,
password=password,
ip_address=ip_address,
is_reattempt=True,
connection_request_id=connection_request_id,
)
def _create_remote_connection(
self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str
) -> None:
"""Create a new TerminalClientConnection Object.
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: session_id: Session ID of connection.
:type: session_id: str
"""
client_connection = RemoteTerminalConnection(
parent_terminal=self,
ssh_session_id=session_id,
connection_uuid=connection_id,
ip_address=source_ip,
connection_request_id=connection_request_id,
time=datetime.now(),
)
self._connections[connection_id] = client_connection
self._client_connection_requests[connection_request_id] = client_connection
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
source_ip = kwargs["frame"].ip.src_ip_address
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
# TODO: uncomment this as part of 2781
username = payload.user_account.username
password = payload.user_account.password
connection_id = self.parent.user_session_manager.remote_login(
username=username, password=password, remote_ip_address=source_ip
)
if connection_id:
connection_request_id = payload.connection_request_uuid
self.sys_log.info(f"{self.name}: Connection authorised, session_id: {session_id}")
self._create_remote_connection(
connection_id=connection_id,
connection_request_id=connection_request_id,
session_id=session_id,
source_ip=source_ip,
)
transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload_contents = {
"type": "login_success",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
"connection_id": connection_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=connection_request_id,
connection_uuid=connection_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.info(f"{self.name}: Login Successful")
self._create_remote_connection(
connection_id=payload.connection_uuid,
connection_request_id=payload.connection_request_uuid,
session_id=session_id,
source_ip=source_ip,
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
self.sys_log.info(f"{self.name}: Received command to execute")
command = payload.ssh_command
valid_connection = self._check_client_connection(payload.connection_uuid)
if valid_connection:
self.execute(command)
return True
else:
self.sys_log.error(
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
)
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":
connection_id = payload["connection_id"]
valid_id = self._check_client_connection(connection_id)
if valid_id:
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
self._disconnect(payload["connection_id"])
self.parent.user_session_manager.remote_logout(remote_session_id=connection_id)
else:
self.sys_log.error(f"{self.name}: No Active connection held for received connection ID.")
if payload["type"] == "user_timeout":
connection_id = payload["connection_id"]
valid_id = self._check_client_connection(connection_id)
if valid_id:
connection = self._connections.pop(connection_id)
connection.is_active = False
self.sys_log.info(f"{self.name}: Connection {connection_id} disconnected due to inactivity.")
else:
self.sys_log.error(f"{self.name}: Connection {connection_id} is invalid.")
return True
def _disconnect(self, connection_uuid: str) -> bool:
"""Disconnect connection.
:param connection_uuid: Connection ID that we want to disconnect.
:return True if successful, False otherwise.
"""
if not self._connections:
self.sys_log.warning(f"{self.name}: No remote connection present")
return False
connection = self._connections.pop(connection_uuid)
connection.is_active = False
if isinstance(connection, RemoteTerminalConnection):
# Send disconnect command via software manager
session_id = connection.ssh_session_id
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid},
dest_port=self.port,
session_id=session_id,
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
elif isinstance(connection, LocalTerminalConnection):
self.parent.user_session_manager.local_logout()
return True
def send(
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
) -> bool:
"""
Send a payload out from the Terminal.
:param payload: The payload to be sent.
:param dest_up_address: The IP address of the payload destination.
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"{self.name}: Cannot send commands when Operating state is {self.operating_state}!")
return False
self.sys_log.debug(f"{self.name}: Sending payload: {payload}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
)

View File

@@ -313,7 +313,7 @@ class IOSoftware(Software):
# if over or at capacity, set to overwhelmed
if len(self._connections) >= self.max_sessions:
self.set_health_state(SoftwareHealthState.OVERWHELMED)
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.")
return False
else:
# if service was previously overwhelmed, set to good because there is enough space for connections
@@ -330,11 +330,11 @@ class IOSoftware(Software):
"ip_address": session_details.with_ip_address if session_details else None,
"time": datetime.now(),
}
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised")
return True
# connection with given id already exists
self.sys_log.warning(
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
f"{self.name}: Connection request ({connection_id}) declined. Connection already exists."
)
return False