From 4796cee2dc7882a592801ece308d6eeebc82ba60 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 26 Jun 2024 16:51:30 +0100 Subject: [PATCH 01/95] #2676: Put global variables in dataclass --- src/primaite/simulator/network/nmne.py | 81 ++++++++++++++------------ 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index 5c0c657b..d6f1763f 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -1,48 +1,55 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, Final, List - -CAPTURE_NMNE: bool = True -"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True.""" - -NMNE_CAPTURE_KEYWORDS: List[str] = [] -"""List of keywords to identify malicious network events.""" - -# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically -CAPTURE_BY_DIRECTION: Final[bool] = True -"""Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" -CAPTURE_BY_IP_ADDRESS: Final[bool] = False -"""Flag to determine if captures should be organized by source or destination IP address.""" -CAPTURE_BY_PROTOCOL: Final[bool] = False -"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP).""" -CAPTURE_BY_PORT: Final[bool] = False -"""Flag to determine if captures should be organized by source or destination port.""" -CAPTURE_BY_KEYWORD: Final[bool] = False -"""Flag to determine if captures should be filtered and categorised based on specific keywords.""" +from dataclasses import dataclass, field +from typing import Dict, List -def set_nmne_config(nmne_config: Dict): +@dataclass +class nmne_data: + """Store all the information to perform NMNE operations.""" + + capture_nmne: bool = True + """Indicates whether Malicious Network Events (MNEs) should be captured.""" + nmne_capture_keywords: List[str] = field(default_factory=list) + """List of keywords to identify malicious network events.""" + capture_by_direction: bool = True + """Captures should be organized by traffic direction (inbound/outbound).""" + capture_by_ip_address: bool = False + """Captures should be organized by source or destination IP address.""" + capture_by_protocol: bool = False + """Captures should be organized by network protocol (e.g., TCP, UDP).""" + capture_by_port: bool = False + """Captures should be organized by source or destination port.""" + capture_by_keyword: bool = False + """Captures should be filtered and categorised based on specific keywords.""" + + +def set_nmne_config(nmne_config: Dict) -> nmne_data: """ - Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary. + Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided + dictionary. - This function updates global settings related to NMNE capture, including whether to capture NMNEs and what - keywords to use for identifying NMNEs. + This function updates global settings related to NMNE capture, including whether to capture + NMNEs and what keywords to use for identifying NMNEs. - The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary, - and maintains type integrity by checking the types of the provided values. + The function ensures that the settings are updated only if they are provided in the + `nmne_config` dictionary, and maintains type integrity by checking the types of the provided + values. - :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: - "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings) - to specify keywords for NMNE identification. + :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys + include: + "capture_nmne" (bool) to indicate whether NMNEs should be captured; + "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. + :rvar dataclass with data read from config file. """ - global NMNE_CAPTURE_KEYWORDS - global CAPTURE_NMNE - + nmne_capture_keywords = [] # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect - CAPTURE_NMNE = nmne_config.get("capture_nmne", False) - if not isinstance(CAPTURE_NMNE, bool): - CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean + capture_nmne = nmne_config.get("capture_nmne", False) + if not isinstance(capture_nmne, bool): + capture_nmne = True # Revert to default True if the provided value is not a boolean # Update the NMNE capture keywords, appending new keywords if provided - NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", []) - if not isinstance(NMNE_CAPTURE_KEYWORDS, list): - NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list + nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) + if not isinstance(nmne_capture_keywords, list): + nmne_capture_keywords = [] # Reset to empty list if the provided value is not a list + + return nmne_data(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) From dbc1d73c34f31a3d01e38471ddaa88834a7dfe63 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 2 Jul 2024 11:15:31 +0100 Subject: [PATCH 02/95] #2676: Update naming of NMNE class --- src/primaite/game/game.py | 13 ++++++++++--- src/primaite/simulator/network/nmne.py | 11 +++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..cc559b4d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.network.nmne import store_nmne_config, NmneData from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -113,6 +113,9 @@ class PrimaiteGame: self._reward_calculation_order: List[str] = [name for name in self.agents] """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards.""" + self.nmne_config: NmneData = None + """ Config data from Number of Malicious Network Events.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -496,10 +499,11 @@ class PrimaiteGame: # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. game.setup_reward_sharing() - # Set the NMNE capture config - set_nmne_config(network_config.get("nmne_config", {})) game.update_agents(game.get_sim_state()) + # Set the NMNE capture config + game.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) + return game def setup_reward_sharing(self): @@ -539,3 +543,6 @@ class PrimaiteGame: # sort the agents so the rewards that depend on other rewards are always evaluated later self._reward_calculation_order = topological_sort(graph) + + def get_nmne_config(self) -> NmneData: + return self.nmne_config diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index d6f1763f..947f27ac 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -4,7 +4,7 @@ from typing import Dict, List @dataclass -class nmne_data: +class NmneData: """Store all the information to perform NMNE operations.""" capture_nmne: bool = True @@ -23,10 +23,9 @@ class nmne_data: """Captures should be filtered and categorised based on specific keywords.""" -def set_nmne_config(nmne_config: Dict) -> nmne_data: +def store_nmne_config(nmne_config: Dict) -> NmneData: """ - Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided - dictionary. + Store configuration for capturing Malicious Network Events (MNEs). This function updates global settings related to NMNE capture, including whether to capture NMNEs and what keywords to use for identifying NMNEs. @@ -41,7 +40,7 @@ def set_nmne_config(nmne_config: Dict) -> nmne_data: "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. :rvar dataclass with data read from config file. """ - nmne_capture_keywords = [] + nmne_capture_keywords: List[str] = [] # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect capture_nmne = nmne_config.get("capture_nmne", False) if not isinstance(capture_nmne, bool): @@ -52,4 +51,4 @@ def set_nmne_config(nmne_config: Dict) -> nmne_data: if not isinstance(nmne_capture_keywords, list): nmne_capture_keywords = [] # Reset to empty list if the provided value is not a list - return nmne_data(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) + return NmneData(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) From bd05f4d4e81b1c5038dc45fc74916de2e53f6fe4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 2 Jul 2024 15:02:59 +0100 Subject: [PATCH 03/95] #2711 - Initial commit of Terminal Service Skeleton framework. Added in a placeholder SSHPacket class. Currently, this allows the Terminal 'service' to be installed onto a HostNode class, and Port 22 - SSH to be visible when using .show(). Functionality and testing still to be completed --- .../system/services/terminal.rst | 26 +++ src/primaite/game/game.py | 2 + .../network/hardware/nodes/host/host_node.py | 2 + .../simulator/network/protocols/ssh.py | 71 +++++++ .../system/services/terminal/__init__.py | 1 + .../system/services/terminal/terminal.py | 190 ++++++++++++++++++ 6 files changed, 292 insertions(+) create mode 100644 docs/source/simulation_components/system/services/terminal.rst create mode 100644 src/primaite/simulator/network/protocols/ssh.py create mode 100644 src/primaite/simulator/system/services/terminal/__init__.py create mode 100644 src/primaite/simulator/system/services/terminal/terminal.py diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst new file mode 100644 index 00000000..bf8072e8 --- /dev/null +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -0,0 +1,26 @@ +.. only:: comment + + © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +.. _Terminal: + +Terminal +######## + +The ``Terminal`` provides a generic terminal simulation, by extending the base Service class + +Key capabilities +================ + + - Authenticates User connection by maintaining an active User account. + - Ensures packets are matched to an existing session + - Simulates common Terminal commands + - Leverages the Service base class for install/uninstall, status tracking etc. + + +Usage +===== + + - Install on a node via the ``SoftwareManager`` to start the Terminal + - Terminal Clients connect, execute commands and disconnect. + - Service runs on SSH port 22 by default. diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..908eecbb 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -38,6 +38,7 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) @@ -60,6 +61,7 @@ SERVICE_TYPES_MAPPING = { "FTPServer": FTPServer, "NTPClient": NTPClient, "NTPServer": NTPServer, + "Terminal": Terminal, } """List of available services that can be installed on nodes in the PrimAITE Simulation.""" diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index fdb28339..5848ade4 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -15,6 +15,7 @@ from primaite.simulator.system.services.arp.arp import ARP, ARPPacket from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.utils.validators import IPV4Address _LOGGER = getLogger(__name__) @@ -306,6 +307,7 @@ class HostNode(Node): "NTPClient": NTPClient, "WebBrowser": WebBrowser, "NMAP": NMAP, + "Terminal": Terminal, } """List of system software that is automatically installed on nodes.""" diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py new file mode 100644 index 00000000..448f0fec --- /dev/null +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -0,0 +1,71 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + +from enum import IntEnum +from typing import Dict, Optional + +from primaite.interface.request import RequestResponse +from primaite.simulator.network.protocols.packet import DataPacket + +# TODO: Elaborate / Confirm / Validate - See 2709. +# Placeholder implementation for Terminal Class implementation. + + +class SSHTransportMessage(IntEnum): + """ + Enum list of Transport layer messages that can be handled by the simulation. + + Each msg value is equivalent to the real-world. + """ + + SSH_MSG_USERAUTH_REQUEST = 50 + """Requests User Authentication.""" + + SSH_MSG_USERAUTH_FAILURE = 51 + """Indicates User Authentication failed.""" + + SSH_MSG_USERAUTH_SUCCESS = 52 + """Indicates User Authentication failed was successful.""" + + SSH_MSG_SERVICE_REQUEST = 24 + """Requests a service - such as executing a command.""" + + # These two msgs are invented for primAITE however are modelled on reality + + SSH_MSG_SERVICE_FAILED = 25 + """Indicates that the requested service failed.""" + + SSH_MSG_SERVICE_SUCCESS = 26 + """Indicates that the requested service was successful.""" + + +class SSHConnectionMessage(IntEnum): + """Int Enum list of all SSH's connection protocol messages that can be handled by the simulation.""" + + SSH_MSG_CHANNEL_OPEN = 80 + """Requests an open channel - Used in combination with SSH_MSG_USERAUTH_REQUEST.""" + + SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 81 + """Confirms an open channel.""" + + SSH_MSG_CHANNEL_OPEN_FAILED = 82 + """Indicates that channel opening failure.""" + + SSH_MSG_CHANNEL_DATA = 84 + """Indicates that data is being sent through the channel.""" + + SSH_MSG_CHANNEL_CLOSE = 87 + """Closes the channel.""" + + +class SSHPacket(DataPacket): + """Represents an SSHPacket.""" + + transport_message: SSHTransportMessage + + connection_message: SSHConnectionMessage + + ssh_command: Optional[any] = None # This is the request string + + ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse + + user_account: Optional[Dict] = None # The user account we will use to login if we do not have a current connection. diff --git a/src/primaite/simulator/system/services/terminal/__init__.py b/src/primaite/simulator/system/services/terminal/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/src/primaite/simulator/system/services/terminal/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py new file mode 100644 index 00000000..d86d21c6 --- /dev/null +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -0,0 +1,190 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from ipaddress import IPv4Address, IPv4Network +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 + +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator +from primaite.simulator.network.protocols.icmp import ICMPPacket + +# from primaite.simulator.network.protocols.ssh import SSHPacket, SSHTransportMessage, SSHConnectionMessage +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service, ServiceOperatingState + + +class Terminal(Service): + """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" + + user_account: Optional[str] = None + "The User Account used for login" + + connected: bool = False + "Boolean Value for whether connected" + + connection_uuid: Optional[str] = None + "Uuid for connection requests" + + def __init__(self, **kwargs): + kwargs["name"] = "Terminal" + kwargs["port"] = Port.SSH + kwargs["protocol"] = IPProtocol.TCP + + super().__init__(**kwargs) + self.operating_state = ServiceOperatingState.RUNNING + + class _LoginValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if we have valid login credentials. + + This should ensure that no actions are resolved without valid user credentials. + """ + + terminal: Terminal + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the login credentials are valid.""" + pass + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator.""" + return ( + f"Cannot perform request on Terminal '{self.terminal.hostname}' because login credentials are invalid" + ) + + def _validate_login(self) -> bool: + """Validate login credentials when receiving commands.""" + # TODO: Implement + return True + + def receive_payload_from_software_manager( + self, + payload: Any, + dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, + session_id: Optional[str] = None, + ip_protocol: IPProtocol = IPProtocol.TCP, + icmp_packet: Optional[ICMPPacket] = None, + connection_id: Optional[str] = None, + ) -> Union[Any, None]: + """Receive Software Manager Payload.""" + self._validate_login() + + def _init_request_manager(self) -> RequestManager: + """Initialise Request manager.""" + # _login_is_valid = Terminal._LoginValidator(terminal=self) + rm = super()._init_request_manager() + + return rm + + def send( + self, + payload: Any, + dest_ip_address: Optional[IPv4Address] = None, + session_id: Optional[str] = None, + ) -> bool: + """Send Request to Software Manager.""" + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) + + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. + + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + state = super().describe_state() + # TBD + state.update({"hostname": self.hostname}) + return state + + def execute(self, command: Any, request: Any) -> Optional[RequestResponse]: + """Execute Command.""" + # Returning the request to the request manager. + if self._validate_login(): + return self.apply_request(request) + else: + self.sys_log.error("Invalid login credentials provided.") + return None + + def apply_request(self, request: List[str | int | float | Dict], context: Dict | None = None) -> RequestResponse: + """Apply Temrinal Request.""" + return super().apply_request(request, context) + + def login(self, dest_ip_address: IPv4Address) -> bool: + """ + Perform an initial login request. + + If this fails, raises an error. + """ + # TODO: This will need elaborating when user accounts are implemented + self.sys_log.info("Attempting Login") + self._ssh_process_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) + + def _generate_connection_id(self) -> str: + """Generate a unique connection ID.""" + return str(uuid4()) + + # %% + + # def _ssh_process_login(self, user_account: dict, **kwargs) -> SSHPacket: + # """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" + # # we assume that the login fails unless we meet all the criteria. + # transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE + # connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED + # # operating state validation here(if overwhelmed) + + # # Hard coded at current - replace with another method to handle local accounts. + # if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded + # connection_id = self._generate_connection_id() + # if not self.add_connection(self, connection_id="ssh_connection", session_id=self.session_id): + # self.sys_log.warning(f"{self.name}: Connect request for {self.src_ip} declined. + # Service is at capacity.") + # ... + # else: + # self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + # transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + # connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION + + # payload: SSHPacket = SSHPacket(transport_message = transport_message, connection_message = connection_message) + # return payload + + # %% + # Copy + Paste from Terminal Wiki + + # def ssh_remote_login(self, dest_ip_address = IPv4Address, user_account: Optional[dict] = None) -> bool: + # if user_account: + # # Setting default creds (Best to use this until we have more clarification on the specifics of user accounts) + # self.user_account = {self.user_name:"placeholder", self.password:"placeholder"} + + # # Implement SSHPacket class + # payload: SSHPacket = SSHPacket(transport_message= SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, + # connection_message= SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + # user_account=user_account) + # if self.send(payload=payload,dest_ip_address=dest_ip_address): + # if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + # self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") + # # Need to confirm if self.uuid is correct. + # self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) + # return True + # else: + # self.sys_log.error("Payload type incorrect, Login Failed") + # return False + # else: + # self.sys_log.error("Incorrect credentials provided. Login Failed.") + # return False + # %% + + def connect(self, **kwargs): + """Send connect request.""" + self._connect(self, **kwargs) + + def _connect(self): + """Do something.""" + pass From ebf6e7a90ebf8b713ac0ae3dc958896a80ed0bea Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 2 Jul 2024 16:47:39 +0100 Subject: [PATCH 04/95] #2711 - Added in remote_login and process_login methods. Minor updates to make pydantic happy. Starting to flesh out functionality of Terminal Service in more detail --- .../simulator/network/protocols/ssh.py | 6 +- .../system/services/terminal/terminal.py | 139 +++++++++--------- 2 files changed, 73 insertions(+), 72 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 448f0fec..7be81982 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -60,11 +60,11 @@ class SSHConnectionMessage(IntEnum): class SSHPacket(DataPacket): """Represents an SSHPacket.""" - transport_message: SSHTransportMessage + transport_message: SSHTransportMessage = None - connection_message: SSHConnectionMessage + connection_message: SSHConnectionMessage = None - ssh_command: Optional[any] = None # This is the request string + ssh_command: Optional[str] = None # This is the request string ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index d86d21c6..e1964f78 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -8,8 +8,7 @@ from uuid import uuid4 from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator from primaite.simulator.network.protocols.icmp import ICMPPacket - -# from primaite.simulator.network.protocols.ssh import SSHPacket, SSHTransportMessage, SSHConnectionMessage +from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import Service, ServiceOperatingState @@ -21,19 +20,22 @@ class Terminal(Service): user_account: Optional[str] = None "The User Account used for login" - connected: bool = False + is_connected: bool = False "Boolean Value for whether connected" connection_uuid: Optional[str] = None "Uuid for connection requests" + operating_state: ServiceOperatingState = ServiceOperatingState.INSTALLING + """Service Operating State""" # Install at start ??? Maybe ??? + def __init__(self, **kwargs): kwargs["name"] = "Terminal" kwargs["port"] = Port.SSH kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.operating_state = ServiceOperatingState.RUNNING + # self.operating_state = ServiceOperatingState.RUNNING class _LoginValidator(RequestPermissionValidator): """ @@ -46,34 +48,22 @@ class Terminal(Service): def __call__(self, request: RequestFormat, context: Dict) -> bool: """Return whether the login credentials are valid.""" - pass + # TODO: Expand & Implement logic when we have User Accounts. + if self.terminal.is_connected: + return True + else: + self.terminal.sys_log.error("terminal is not logged in.") @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return ( - f"Cannot perform request on Terminal '{self.terminal.hostname}' because login credentials are invalid" - ) + return f"Cannot perform request on Terminal '{self.terminal.name}' because login credentials are invalid" def _validate_login(self) -> bool: """Validate login credentials when receiving commands.""" # TODO: Implement return True - def receive_payload_from_software_manager( - self, - payload: Any, - dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, - icmp_packet: Optional[ICMPPacket] = None, - connection_id: Optional[str] = None, - ) -> Union[Any, None]: - """Receive Software Manager Payload.""" - self._validate_login() - def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" # _login_is_valid = Terminal._LoginValidator(terminal=self) @@ -101,7 +91,7 @@ class Terminal(Service): """ state = super().describe_state() # TBD - state.update({"hostname": self.hostname}) + state.update({"hostname": self.name}) return state def execute(self, command: Any, request: Any) -> Optional[RequestResponse]: @@ -133,58 +123,69 @@ class Terminal(Service): # %% - # def _ssh_process_login(self, user_account: dict, **kwargs) -> SSHPacket: - # """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" - # # we assume that the login fails unless we meet all the criteria. - # transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE - # connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED - # # operating state validation here(if overwhelmed) + def _ssh_process_login(self, user_account: dict, **kwargs) -> SSHPacket: + """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" + # we assume that the login fails unless we meet all the criteria. + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED + # operating state validation here(if overwhelmed) - # # Hard coded at current - replace with another method to handle local accounts. - # if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded - # connection_id = self._generate_connection_id() - # if not self.add_connection(self, connection_id="ssh_connection", session_id=self.session_id): - # self.sys_log.warning(f"{self.name}: Connect request for {self.src_ip} declined. - # Service is at capacity.") - # ... - # else: - # self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") - # transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS - # connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION + # Hard coded at current - replace with another method to handle local accounts. + if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded + connection_id = self._generate_connection_id() + if not self.add_connection(self, connection_id="ssh_connection", session_id=self.session_id): + self.sys_log.warning( + f"{self.name}: Connect request for {self.src_ip} declined. Service is at capacity." + ) + else: + self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION + self.is_connected = True - # payload: SSHPacket = SSHPacket(transport_message = transport_message, connection_message = connection_message) - # return payload + payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) + return payload # %% # Copy + Paste from Terminal Wiki - # def ssh_remote_login(self, dest_ip_address = IPv4Address, user_account: Optional[dict] = None) -> bool: - # if user_account: - # # Setting default creds (Best to use this until we have more clarification on the specifics of user accounts) - # self.user_account = {self.user_name:"placeholder", self.password:"placeholder"} + def ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: + """Remote login to terminal via SSH.""" + if user_account: + # Setting default creds (Best to use this until we have more clarification around user accounts) + self.user_account = {self.user_name: "placeholder", self.password: "placeholder"} + + # Implement SSHPacket class + payload: SSHPacket = SSHPacket( + transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + user_account=user_account, + ) + if self.send(payload=payload, dest_ip_address=dest_ip_address): + if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") + # Need to confirm if self.uuid is correct. + self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) + return True + else: + self.sys_log.error("Payload type incorrect, Login Failed") + return False + else: + self.sys_log.error("Incorrect credentials provided. Login Failed.") + return False - # # Implement SSHPacket class - # payload: SSHPacket = SSHPacket(transport_message= SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, - # connection_message= SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - # user_account=user_account) - # if self.send(payload=payload,dest_ip_address=dest_ip_address): - # if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - # self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") - # # Need to confirm if self.uuid is correct. - # self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) - # return True - # else: - # self.sys_log.error("Payload type incorrect, Login Failed") - # return False - # else: - # self.sys_log.error("Incorrect credentials provided. Login Failed.") - # return False # %% - def connect(self, **kwargs): - """Send connect request.""" - self._connect(self, **kwargs) - - def _connect(self): - """Do something.""" - pass + def receive_payload_from_software_manager( + self, + payload: Any, + dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, + src_port: Optional[Port] = None, + dst_port: Optional[Port] = None, + session_id: Optional[str] = None, + ip_protocol: IPProtocol = IPProtocol.TCP, + icmp_packet: Optional[ICMPPacket] = None, + connection_id: Optional[str] = None, + ) -> Union[Any, None]: + """Receive Software Manager Payload.""" + self._validate_login() From 47df2aa56940c26047c4e2b6672867e9016c8b1f Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 4 Jul 2024 15:41:13 +0100 Subject: [PATCH 05/95] #2676: Store NMNE config data in class variable. --- src/primaite/game/game.py | 13 ++---- .../simulator/network/hardware/base.py | 42 +++++++------------ 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cc559b4d..9636bd23 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -15,7 +15,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort -from primaite.simulator.network.hardware.base import NodeOperatingState +from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server @@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import store_nmne_config, NmneData +from primaite.simulator.network.nmne import NmneData, store_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -239,6 +239,8 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) + # Set the NMNE capture config + NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -500,10 +502,6 @@ class PrimaiteGame: game.setup_reward_sharing() game.update_agents(game.get_sim_state()) - - # Set the NMNE capture config - game.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) - return game def setup_reward_sharing(self): @@ -543,6 +541,3 @@ class PrimaiteGame: # sort the agents so the rewards that depend on other rewards are always evaluated later self._reward_calculation_order = topological_sort(graph) - - def get_nmne_config(self) -> NmneData: - return self.nmne_config diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 01745215..6d753731 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,12 +6,11 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field -import primaite.simulator.network.nmne from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse @@ -20,15 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.nmne import ( - CAPTURE_BY_DIRECTION, - CAPTURE_BY_IP_ADDRESS, - CAPTURE_BY_KEYWORD, - CAPTURE_BY_PORT, - CAPTURE_BY_PROTOCOL, - CAPTURE_NMNE, - NMNE_CAPTURE_KEYWORDS, -) +from primaite.simulator.network.nmne import NmneData from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application @@ -108,8 +99,8 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." - nmne: Dict = Field(default_factory=lambda: {}) - "A dict containing details of the number of malicious network events captured." + nmne_config: ClassVar[NmneData] = None + "A dataclass defining malicious network events to be captured." traffic: Dict = Field(default_factory=lambda: {}) "A dict containing details of the inbound and outbound traffic by port and protocol." @@ -117,7 +108,6 @@ class NetworkInterface(SimComponent, ABC): def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().setup_for_episode(episode=episode) - self.nmne = {} self.traffic = {} if episode and self.pcap and SIM_OUTPUT.save_pcap_logs: self.pcap.current_episode = episode @@ -152,8 +142,8 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) - if CAPTURE_NMNE: - state.update({"nmne": {k: v for k, v in self.nmne.items()}}) + if self.nmne_config and self.nmne_config.capture_nmne: + state.update({"nmne": {self.nmne_config.__dict__}}) state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)}) return state @@ -186,7 +176,7 @@ class NetworkInterface(SimComponent, ABC): :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. """ # Exit function if NMNE capturing is disabled - if not CAPTURE_NMNE: + if not (self.nmne_config and self.nmne_config.capture_nmne): return # Initialise basic frame data variables @@ -207,27 +197,27 @@ class NetworkInterface(SimComponent, ABC): frame_str = str(frame.payload) # Proceed only if any NMNE keyword is present in the frame payload - if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords): # Start with the root of the NMNE capture structure - current_level = self.nmne + current_level = self.nmne_config # Update NMNE structure based on enabled settings - if CAPTURE_BY_DIRECTION: + if self.nmne_config.capture_by_direction: # Set or get the dictionary for the current direction current_level = current_level.setdefault("direction", {}) current_level = current_level.setdefault(direction, {}) - if CAPTURE_BY_IP_ADDRESS: + if self.nmne_config.capture_by_ip_address: # Set or get the dictionary for the current IP address current_level = current_level.setdefault("ip_address", {}) current_level = current_level.setdefault(ip_address, {}) - if CAPTURE_BY_PROTOCOL: + if self.nmne_config.capture_by_protocol: # Set or get the dictionary for the current protocol current_level = current_level.setdefault("protocol", {}) current_level = current_level.setdefault(protocol, {}) - if CAPTURE_BY_PORT: + if self.nmne_config.capture_by_port: # Set or get the dictionary for the current port current_level = current_level.setdefault("port", {}) current_level = current_level.setdefault(port, {}) @@ -236,8 +226,8 @@ class NetworkInterface(SimComponent, ABC): keyword_level = current_level.setdefault("keywords", {}) # Increment the count for detected keywords in the payload - if CAPTURE_BY_KEYWORD: - for keyword in NMNE_CAPTURE_KEYWORDS: + if self.nmne_config.capture_by_keyword: + for keyword in self.nmne_config.nmne_capture_keywords: if keyword in frame_str: # Update the count for each keyword found keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 @@ -1067,7 +1057,7 @@ class Node(SimComponent): ip_address, network_interface.speed, "Enabled" if network_interface.enabled else "Disabled", - network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled", + network_interface.nmne if self.nmne_config.capture_nmne else "Disabled", ] ) print(table) From 3867ec40c9c571d719ff6fe818c505721a99cbc0 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 4 Jul 2024 17:05:00 +0100 Subject: [PATCH 06/95] #2676: Fix nmne_config dict conversion --- src/primaite/simulator/network/hardware/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6d753731..3c52a65d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +from dataclasses import asdict import re import secrets from abc import ABC, abstractmethod @@ -143,7 +144,7 @@ class NetworkInterface(SimComponent, ABC): } ) if self.nmne_config and self.nmne_config.capture_nmne: - state.update({"nmne": {self.nmne_config.__dict__}}) + state.update({"nmne": asdict(self.nmne_config)}) state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)}) return state From 589ea2fed4e96e14365a4f92c504ad7fbf21b6c2 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 5 Jul 2024 12:19:52 +0100 Subject: [PATCH 07/95] #2676: Add local nmne dict --- src/primaite/simulator/network/hardware/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 3c52a65d..e611f9b2 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -103,12 +103,16 @@ class NetworkInterface(SimComponent, ABC): nmne_config: ClassVar[NmneData] = None "A dataclass defining malicious network events to be captured." + nmne: Dict = Field(default_factory=lambda: {}) + "A dict containing details of the number of malicious events captured." + traffic: Dict = Field(default_factory=lambda: {}) "A dict containing details of the inbound and outbound traffic by port and protocol." def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" super().setup_for_episode(episode=episode) + self.nmne = {} self.traffic = {} if episode and self.pcap and SIM_OUTPUT.save_pcap_logs: self.pcap.current_episode = episode @@ -144,7 +148,7 @@ class NetworkInterface(SimComponent, ABC): } ) if self.nmne_config and self.nmne_config.capture_nmne: - state.update({"nmne": asdict(self.nmne_config)}) + state.update({"nmne": self.nmne}) state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)}) return state @@ -200,7 +204,7 @@ class NetworkInterface(SimComponent, ABC): # Proceed only if any NMNE keyword is present in the frame payload if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords): # Start with the root of the NMNE capture structure - current_level = self.nmne_config + current_level = self.nmne # Update NMNE structure based on enabled settings if self.nmne_config.capture_by_direction: From 18ae3acf3734f36a86cf46045ef5f0ed5c51e02c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 5 Jul 2024 14:09:39 +0100 Subject: [PATCH 08/95] #2676: Update nmne tests --- src/primaite/simulator/network/hardware/base.py | 1 - .../network/test_capture_nmne.py | 16 +++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index e611f9b2..f161b2b5 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1,7 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from dataclasses import asdict import re import secrets from abc import ABC, abstractmethod diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index a8f1f245..f6e4c685 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,12 +1,14 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection -def test_capture_nmne(uc2_network): +def test_capture_nmne(uc2_network: Network): """ Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured. @@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = store_nmne_config(nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers assert web_server_nic.nmne == {} @@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network): assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}} -def test_describe_state_nmne(uc2_network): +def test_describe_state_nmne(uc2_network: Network): """ Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state. @@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = store_nmne_config(nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers web_server_nic_state = web_server_nic.describe_state() @@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network): assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}} -def test_capture_nmne_observations(uc2_network): +def test_capture_nmne_observations(uc2_network: Network): """ Tests the NICObservation class's functionality within a simulated network environment. @@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network): } # Apply the NMNE configuration settings - set_nmne_config(nmne_config) + NIC.nmne_config = store_nmne_config(nmne_config) # Define observations for the NICs of the database and web servers db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True) From 219d448adc0f7a0be2bbaa5c246af46a25cb66b4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 8 Jul 2024 07:58:10 +0100 Subject: [PATCH 09/95] #2711 - Rewrite of the majority of the terminal class after not liking how I originally did it. This takes a heavier inspiration for handling connections from the database_client/server --- .../system/services/terminal.rst | 30 +++ .../system/services/terminal/terminal.py | 254 ++++++++++-------- 2 files changed, 171 insertions(+), 113 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index bf8072e8..afa79c0a 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -24,3 +24,33 @@ Usage - Install on a node via the ``SoftwareManager`` to start the Terminal - Terminal Clients connect, execute commands and disconnect. - Service runs on SSH port 22 by default. + +Implementation +============== + +- Manages SSH commands +- Ensures User login before sending commands +- Processes SSH commands +- Returns results in a ** format. + + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.services.terminal.terminal import Terminal + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + + client = Computer( + hostname="client", + ip_address="192.168.10.21", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + operating_state=NodeOperatingState.ON, + ) + + terminal: Terminal = client.software_manager.software.get("Terminal") diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e1964f78..5f8719ac 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -1,19 +1,57 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, List, Optional, Union +from ipaddress import IPv4Address +from typing import Dict, List, Optional from uuid import uuid4 -from primaite.interface.request import RequestFormat, RequestResponse -from primaite.simulator.core import RequestManager, RequestPermissionValidator -from primaite.simulator.network.protocols.icmp import ICMPPacket +from pydantic import BaseModel + +from primaite.interface.request import RequestResponse +from primaite.simulator.core import RequestManager +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +class TerminalClientConnection(BaseModel): + """ + TerminalClientConnection Class. + + This class is used to record current User Connections within the Terminal class. + """ + + connection_id: str + """Connection UUID.""" + + parent_node: HostNode + """The parent Node that this connection was created on.""" + + is_active: bool = True + """Flag to state whether the connection is still active or not.""" + + _dest_ip_address: IPv4Address + """Destination IP address of connection""" + + @property + def dest_ip_address(self) -> Optional[IPv4Address]: + """Destination IP Address.""" + return self._dest_ip_address + + @property + def client(self) -> Optional[Terminal]: + """The Terminal that holds this connection.""" + return self.parent_node.software_manager.software.get("Terminal") + + def disconnect(self): + """Disconnect the connection.""" + if self.client and self.is_active: + self.client._disconnect(self.connection_id) # noqa + + class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" @@ -26,59 +64,17 @@ class Terminal(Service): connection_uuid: Optional[str] = None "Uuid for connection requests" - operating_state: ServiceOperatingState = ServiceOperatingState.INSTALLING - """Service Operating State""" # Install at start ??? Maybe ??? + operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING + """Initial Operating State""" + + user_connections: Dict[str, TerminalClientConnection] = {} + """List of authenticated connected users""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" kwargs["port"] = Port.SSH kwargs["protocol"] = IPProtocol.TCP - super().__init__(**kwargs) - # self.operating_state = ServiceOperatingState.RUNNING - - class _LoginValidator(RequestPermissionValidator): - """ - When requests come in, this validator will only let them through if we have valid login credentials. - - This should ensure that no actions are resolved without valid user credentials. - """ - - terminal: Terminal - - def __call__(self, request: RequestFormat, context: Dict) -> bool: - """Return whether the login credentials are valid.""" - # TODO: Expand & Implement logic when we have User Accounts. - if self.terminal.is_connected: - return True - else: - self.terminal.sys_log.error("terminal is not logged in.") - - @property - def fail_message(self) -> str: - """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on Terminal '{self.terminal.name}' because login credentials are invalid" - - def _validate_login(self) -> bool: - """Validate login credentials when receiving commands.""" - # TODO: Implement - return True - - def _init_request_manager(self) -> RequestManager: - """Initialise Request manager.""" - # _login_is_valid = Terminal._LoginValidator(terminal=self) - rm = super()._init_request_manager() - - return rm - - def send( - self, - payload: Any, - dest_ip_address: Optional[IPv4Address] = None, - session_id: Optional[str] = None, - ) -> bool: - """Send Request to Software Manager.""" - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) def describe_state(self) -> Dict: """ @@ -90,23 +86,64 @@ class Terminal(Service): :rtype: Dict """ state = super().describe_state() - # TBD + state.update({"hostname": self.name}) return state - def execute(self, command: Any, request: Any) -> Optional[RequestResponse]: - """Execute Command.""" - # Returning the request to the request manager. - if self._validate_login(): - return self.apply_request(request) - else: - self.sys_log.error("Invalid login credentials provided.") - return None - def apply_request(self, request: List[str | int | float | Dict], context: Dict | None = None) -> RequestResponse: """Apply Temrinal Request.""" return super().apply_request(request, context) + def _init_request_manager(self) -> RequestManager: + """Initialise Request manager.""" + # TODO: Expand with a login validator? + rm = super()._init_request_manager() + return rm + + # %% Inbound + + def _generate_connection_id(self) -> str: + """Generate a unique connection ID.""" + return str(uuid4()) + + def process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: + """Process User request to login to Terminal.""" + if user_account in self.user_connections: + self.sys_log.debug("User authentication passed") + return True + else: + self._ssh_process_login(dest_ip_address=dest_ip_address, user_account=user_account) + self.process_login(dest_ip_address=dest_ip_address, user_account=user_account) + + def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: + """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" + # we assume that the login fails unless we meet all the criteria. + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED + + # Hard coded at current - replace with another method to handle local accounts. + if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded + connection_id = self._generate_connection_id() + if not self.add_connection(self, connection_id=connection_id): + self.sys_log.warning( + f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." + ) + return False + else: + self.sys_log.info(f"{self.name}: Connect request for ID: {connection_id} authorised") + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION + new_connection = TerminalClientConnection(connection_id=connection_id, dest_ip_address=dest_ip_address) + self.user_connections[connection_id] = new_connection + self.is_connected = True + + payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) + + self.send(payload=payload, dest_ip_address=dest_ip_address) + return True + + # %% Outbound + def login(self, dest_ip_address: IPv4Address) -> bool: """ Perform an initial login request. @@ -115,45 +152,13 @@ class Terminal(Service): """ # TODO: This will need elaborating when user accounts are implemented self.sys_log.info("Attempting Login") - self._ssh_process_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) - - def _generate_connection_id(self) -> str: - """Generate a unique connection ID.""" - return str(uuid4()) - - # %% - - def _ssh_process_login(self, user_account: dict, **kwargs) -> SSHPacket: - """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" - # we assume that the login fails unless we meet all the criteria. - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED - # operating state validation here(if overwhelmed) - - # Hard coded at current - replace with another method to handle local accounts. - if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded - connection_id = self._generate_connection_id() - if not self.add_connection(self, connection_id="ssh_connection", session_id=self.session_id): - self.sys_log.warning( - f"{self.name}: Connect request for {self.src_ip} declined. Service is at capacity." - ) - else: - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - self.is_connected = True - - payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) - return payload - - # %% - # Copy + Paste from Terminal Wiki + return self.ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) def ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" - if user_account: + if not user_account: # Setting default creds (Best to use this until we have more clarification around user accounts) - self.user_account = {self.user_name: "placeholder", self.password: "placeholder"} + user_account = {self.user_name: "placeholder", self.password: "placeholder"} # Implement SSHPacket class payload: SSHPacket = SSHPacket( @@ -161,6 +166,7 @@ class Terminal(Service): connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, user_account=user_account, ) + # self.send will return bool, payload unchanged? if self.send(payload=payload, dest_ip_address=dest_ip_address): if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") @@ -168,24 +174,46 @@ class Terminal(Service): self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) return True else: - self.sys_log.error("Payload type incorrect, Login Failed") + self.sys_log.error("Login Failed. Incorrect credentials provided.") return False else: - self.sys_log.error("Incorrect credentials provided. Login Failed.") + self.sys_log.error("Login Failed. Incorrect credentials provided.") return False - # %% + def check_connection(self, connection_id: str) -> bool: + """Check whether the connection is valid.""" + if self.is_connected: + return self.send(dest_ip_address=self.dest_ip_address, connection_id=connection_id) + else: + return False - def receive_payload_from_software_manager( - self, - payload: Any, - dst_ip_address: Optional[Union[IPv4Address, IPv4Network]] = None, - src_port: Optional[Port] = None, - dst_port: Optional[Port] = None, - session_id: Optional[str] = None, - ip_protocol: IPProtocol = IPProtocol.TCP, - icmp_packet: Optional[ICMPPacket] = None, - connection_id: Optional[str] = None, - ) -> Union[Any, None]: - """Receive Software Manager Payload.""" - self._validate_login() + def disconnect(self, connection_id: str): + """Disconnect from remote.""" + self._disconnect(connection_id) + self.is_connected = False + + def _disconnect(self, connection_id: str) -> bool: + if not self.is_connected: + return False + + if len(self.user_connections) == 0: + self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") + return False + if not self.user_connections.get(connection_id): + return False + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": connection_id}, + dest_ip_address=self.server_ip_address, + dest_port=self.port, + ) + connection = self.user_connections.pop(connection_id) + self.terminate_connection(connection_id=connection_id) + + connection.is_active = False + + self.sys_log.info( + f"{self.name}: Disconnected {connection_id} from: {self.user_connections[connection_id]._dest_ip_address}" + ) + self.connected = False + return True From 252214b4689444fb6f54f53a9ae9199158b1c5e4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 8 Jul 2024 08:25:42 +0100 Subject: [PATCH 10/95] #2711 Updating Changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17bf3557..1f2db4f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,8 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config. - Added support for SQL INSERT command. - Added ability to log each agent's action choices in each step to a JSON file. -- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present. +- Removal of Link bandwidth hardcoding. This can now be configured via the network configuration yaml. Will default to 100 if not present. - Added NMAP application to all host and layer-3 network nodes. +- Added Terminal Class for HostNode components ### Bug Fixes From 47a1daa5806bb611ff3bb6ee12368e6dd8d53a52 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 8 Jul 2024 15:10:06 +0100 Subject: [PATCH 11/95] #2735 - Initial work done around User, UserManager, and UserSessionManager --- .../simulator/network/hardware/base.py | 22 ++- .../network/hardware/nodes/host/host_node.py | 9 +- .../system/services/access/__init__.py | 1 + .../system/services/access/user_manager.py | 186 ++++++++++++++++++ .../services/access/user_session_manager.py | 98 +++++++++ .../simulator/system/services/service.py | 2 +- src/primaite/simulator/system/software.py | 2 +- 7 files changed, 308 insertions(+), 12 deletions(-) create mode 100644 src/primaite/simulator/system/services/access/__init__.py create mode 100644 src/primaite/simulator/system/services/access/user_manager.py create mode 100644 src/primaite/simulator/system/services/access/user_session_manager.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6942d280..ada9c57a 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,7 +6,7 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, Optional, TypeVar, Union +from typing import Any, ClassVar, Dict, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field @@ -37,6 +37,8 @@ from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.processes.process import Process +from primaite.simulator.system.services.access.user_manager import UserManager +from primaite.simulator.system.services.access.user_session_manager import UserSessionManager from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware from primaite.utils.converters import convert_dict_enum_keys_to_enum_values @@ -821,7 +823,16 @@ class Node(SimComponent): super().__init__(**kwargs) self.session_manager.node = self self.session_manager.software_manager = self.software_manager - self._install_system_software() + self.software_manager.install(UserSessionManager) + self.software_manager.install(UserManager) + + # @property + # def user_manager(self) -> UserManager: + # return self.software_manager.software["UserManager"] # noqa + # + # @property + # def _user_session_manager(self) -> UserSessionManager: + # return self.software_manager.software["UserSessionManager"] # noqa def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: """ @@ -876,7 +887,7 @@ class Node(SimComponent): @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator.""" - return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on." + return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on." def _init_request_manager(self) -> RequestManager: """ @@ -1000,10 +1011,6 @@ class Node(SimComponent): return rm - def _install_system_software(self): - """Install System Software - software that is usually provided with the OS.""" - pass - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -1184,6 +1191,7 @@ class Node(SimComponent): def pre_timestep(self, timestep: int) -> None: """Apply pre-timestep logic.""" super().pre_timestep(timestep) + self._ for network_interface in self.network_interfaces.values(): network_interface.pre_timestep(timestep=timestep) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index fdb28339..80f80a04 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -11,6 +11,8 @@ from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.access.user_manager import UserManager +from primaite.simulator.system.services.access.user_session_manager import UserSessionManager from primaite.simulator.system.services.arp.arp import ARP, ARPPacket from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP @@ -306,6 +308,8 @@ class HostNode(Node): "NTPClient": NTPClient, "WebBrowser": WebBrowser, "NMAP": NMAP, + # "UserSessionManager": UserSessionManager, + # "UserManager": UserManager, } """List of system software that is automatically installed on nodes.""" @@ -314,9 +318,10 @@ class HostNode(Node): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, username: str, password: str, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + self.user_manager.add_user(username=username, password=password, is_admin=True, bypass_can_perform_action=True) @property def nmap(self) -> Optional[NMAP]: @@ -348,8 +353,6 @@ class HostNode(Node): for _, software_class in self.SYSTEM_SOFTWARE.items(): self.software_manager.install(software_class) - super()._install_system_software() - def default_gateway_hello(self): """ Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address. diff --git a/src/primaite/simulator/system/services/access/__init__.py b/src/primaite/simulator/system/services/access/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/src/primaite/simulator/system/services/access/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/src/primaite/simulator/system/services/access/user_manager.py b/src/primaite/simulator/system/services/access/user_manager.py new file mode 100644 index 00000000..09f8950e --- /dev/null +++ b/src/primaite/simulator/system/services/access/user_manager.py @@ -0,0 +1,186 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Dict, Optional + +from prettytable import MARKDOWN, PrettyTable +from pydantic import Field + +from primaite.simulator.core import SimComponent +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.service import Service + + +class User(SimComponent): + """ + Represents a user in the PrimAITE system. + + :param username: The username of the user + :param password: The password of the user + :param disabled: Boolean flag indicating whether the user is disabled + :param is_admin: Boolean flag indicating whether the user has admin privileges + """ + + username: str + password: str + disabled: bool = False + is_admin: bool = False + + def describe_state(self) -> Dict: + """ + Returns a dictionary representing the current state of the user. + + :return: A dict containing the state of the user + """ + return self.model_dump() + + +class UserManager(Service): + """ + Manages users within the PrimAITE system, handling creation, authentication, and administration. + + :param users: A dictionary of all users by their usernames + :param admins: A dictionary of admin users by their usernames + :param disabled_admins: A dictionary of currently disabled admin users by their usernames + """ + + users: Dict[str, User] = Field(default_factory=dict) + admins: Dict[str, User] = Field(default_factory=dict) + disabled_admins: Dict[str, User] = Field(default_factory=dict) + + def __init__(self, **kwargs): + """ + Initializes a UserManager instanc. + + :param username: The username for the default admin user + :param password: The password for the default admin user + """ + kwargs["name"] = "UserManager" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) + self.start() + + def describe_state(self) -> Dict: + """ + Returns the state of the UserManager along with the number of users and admins. + + :return: A dict containing detailed state information + """ + state = super().describe_state() + state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)}) + return state + + def show(self, markdown: bool = False): + """ + Display the Users. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + table = PrettyTable(["Username", "Admin", "Enabled"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} User Manager)" + for user in self.users.values(): + table.add_row([user.username, user.is_admin, user.disabled]) + print(table.get_string(sortby="Username")) + + def _is_last_admin(self, username: str) -> bool: + return username in self.admins and len(self.admins) == 1 + + def add_user( + self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False + ) -> bool: + """ + Adds a new user to the system. + + :param username: The username for the new user + :param password: The password for the new user + :param is_admin: Flag indicating if the new user is an admin + :return: True if user was successfully added, False otherwise + """ + if not bypass_can_perform_action and not self._can_perform_action(): + return False + if username in self.users: + return False + user = User(username=username, password=password, is_admin=is_admin) + self.users[username] = user + if is_admin: + self.admins[username] = user + self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}") + return True + + def authenticate_user(self, username: str, password: str) -> Optional[User]: + """ + Authenticates a user's login attempt. + + :param username: The username of the user trying to log in + :param password: The password provided by the user + :return: The User object if authentication is successful, None otherwise + """ + if not self._can_perform_action(): + return None + user = self.users.get(username) + if user and not user.disabled and user.password == password: + self.sys_log.info(f"{self.name}: User authenticated: {username}") + return user + self.sys_log.info(f"{self.name}: Authentication failed for: {username}") + return None + + def change_user_password(self, username: str, current_password: str, new_password: str) -> bool: + """ + Changes a user's password. + + :param username: The username of the user changing their password + :param current_password: The current password of the user + :param new_password: The new password for the user + :return: True if the password was changed successfully, False otherwise + """ + if not self._can_perform_action(): + return False + user = self.users.get(username) + if user and user.password == current_password: + user.password = new_password + self.sys_log.info(f"{self.name}: Password changed for {username}") + return True + self.sys_log.info(f"{self.name}: Password change failed for {username}") + return False + + def disable_user(self, username: str) -> bool: + """ + Disables a user account, preventing them from logging in. + + :param username: The username of the user to disable + :return: True if the user was disabled successfully, False otherwise + """ + if not self._can_perform_action(): + return False + if username in self.users and not self.users[username].disabled: + if self._is_last_admin(username): + self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin") + return False + self.users[username].disabled = True + self.sys_log.info(f"{self.name}: User disabled: {username}") + if username in self.admins: + self.disabled_admins[username] = self.admins.pop(username) + return True + self.sys_log.info(f"{self.name}: Failed to disable user: {username}") + return False + + def enable_user(self, username: str) -> bool: + """ + Enables a previously disabled user account. + + :param username: The username of the user to enable + :return: True if the user was enabled successfully, False otherwise + """ + if not self._can_perform_action(): + return False + if username in self.users and self.users[username].disabled: + self.users[username].disabled = False + self.sys_log.info(f"{self.name}: User enabled: {username}") + if username in self.disabled_admins: + self.admins[username] = self.disabled_admins.pop(username) + return True + self.sys_log.info(f"{self.name}: Failed to enable user: {username}") + return False diff --git a/src/primaite/simulator/system/services/access/user_session_manager.py b/src/primaite/simulator/system/services/access/user_session_manager.py new file mode 100644 index 00000000..03d2dd93 --- /dev/null +++ b/src/primaite/simulator/system/services/access/user_session_manager.py @@ -0,0 +1,98 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Dict, List, Optional +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from primaite.simulator.core import SimComponent +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.access.user_manager import User, UserManager +from primaite.simulator.system.services.service import Service +from primaite.utils.validators import IPV4Address + + +class UserSession(SimComponent): + user: User + start_step: int + last_active_step: int + end_step: Optional[int] = None + local: bool = True + + @classmethod + def create(cls, user: User, timestep: int) -> UserSession: + return UserSession(user=user, start_step=timestep, last_active_step=timestep) + def describe_state(self) -> Dict: + return self.model_dump() + + +class RemoteUserSession(UserSession): + remote_ip_address: IPV4Address + local: bool = False + + def describe_state(self) -> Dict: + state = super().describe_state() + state["remote_ip_address"] = str(self.remote_ip_address) + return state + + +class UserSessionManager(BaseModel): + node: + local_session: Optional[UserSession] = None + remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict) + historic_sessions: List[UserSession] = Field(default_factory=list) + + local_session_timeout_steps: int = 30 + remote_session_timeout_steps: int = 5 + max_remote_sessions: int = 3 + + current_timestep: int = 0 + + @property + def _user_manager(self) -> UserManager: + return self.software_manager.software["UserManager"] # noqa + + def pre_timestep(self, timestep: int) -> None: + """Apply any pre-timestep logic that helps make sure we have the correct observations.""" + self.current_timestep = timestep + if self.local_session: + if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: + self._timeout_session(self.local_session) + + def _timeout_session(self, session: UserSession) -> None: + session.end_step = self.current_timestep + session_identity = session.user.username + if session.local: + self.local_session = None + session_type = "Local" + else: + self.remote_sessions.pop(session.uuid) + session_type = "Remote" + session_identity = f"{session_identity} {session.remote_ip_address}" + + self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") + + def login(self, username: str, password: str) -> Optional[str]: + if not self._can_perform_action(): + return None + user = self._user_manager.authenticate_user(username=username, password=password) + if user: + self.logout() + self.local_session = UserSession.create(user=user, timestep=self.current_timestep) + self.sys_log.info(f"{self.name}: User {user.username} logged in") + return self.local_session.uuid + else: + self.sys_log.info(f"{self.name}: Incorrect username or password") + + def logout(self): + if not self._can_perform_action(): + return False + if self.local_session: + session = self.local_session + session.end_step = self.current_timestep + self.historic_sessions.append(session) + self.local_session = None + self.sys_log.info(f"{self.name}: User {session.user.username} logged out") diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index e6ce2c87..bef9804f 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -43,7 +43,7 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." - def __init__(self, **kwargs): + def __init__(self, **kwargs):c super().__init__(**kwargs) def _can_perform_action(self) -> bool: diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7ea67dcd..7c27534a 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -291,7 +291,7 @@ class IOSoftware(Software): """ if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON: self.software_manager.node.sys_log.error( - f"{self.name} Error: {self.software_manager.node.hostname} is not online." + f"{self.name} Error: {self.software_manager.node.hostname} is not powered on." ) return False return True From 42602be953470c61caad47cfe9e813a6c440fa28 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 9 Jul 2024 11:54:33 +0100 Subject: [PATCH 12/95] #2710 - Initial implementation f the receive/send methods. Committing to change branch --- .../network/hardware/nodes/host/host_node.py | 1 + .../system/services/terminal/terminal.py | 68 +++++++++++++++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 5848ade4..1fb936cd 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -293,6 +293,7 @@ class HostNode(Node): * DNS (Domain Name System) Client: Resolves domain names to IP addresses. * FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers. * NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers. + * Terminal Client: Handles SSH requests between HostNode and external components. Applications: ------------ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 5f8719ac..bf852823 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -2,14 +2,14 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from uuid import uuid4 from pydantic import BaseModel from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager -from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -27,7 +27,7 @@ class TerminalClientConnection(BaseModel): connection_id: str """Connection UUID.""" - parent_node: HostNode + parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" is_active: bool = True @@ -116,7 +116,7 @@ class Terminal(Service): self.process_login(dest_ip_address=dest_ip_address, user_account=user_account) def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Processes the login attempt. Returns a SSHPacket which either rejects the login or accepts it.""" + """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" # we assume that the login fails unless we meet all the criteria. transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED @@ -142,6 +142,56 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True + def validate_user(self, user: Dict[str]) -> bool: + return True if user.get("username") in self.user_connections else False + + + def _ssh_process_logoff(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: + """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" + + if self.validate_user(user_account): + # Account is logged in + self.user_connections.pop[user_account["username"]] # assumption atm + self.is_connected = False + return True + else: + self.sys_log.warning("User account credentials invalid.") + + def _ssh_process_command(self, session_id: str, *args, **kwargs) -> bool: + return True + + def send_logoff_ack(self): + """Send confirmation of successful disconnect""" + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE + payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message, ssh_output=RequestResponse(status="success")) + self.send(payload=payload) + + def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") + if payload.connection_message ==SSHConnectionMessage. SSH_MSG_CHANNEL_CLOSE: + result = self._ssh_process_logoff(session_id=session_id) + # We need to close on the other machine as well + self.send_logoff_ack() + + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: + src_ip = kwargs.get("frame").ip.src_ip_address + user_account = payload.get("user_account", {}) + result = self._ssh_process_login(src_ip=src_ip, session_id=session_id, user_account=user_account) + + elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: + # Ensure we only ever process requests if we have a established connection (e.g session_id is provided and validated) + result = self._ssh_process_command(session_id=session_id) + + else: + self.sys_log.warning("Encounter unexpected message type, rejecting connection") + # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED + return False + + self.send(payload=result, session_id=session_id) + return True + + # %% Outbound def login(self, dest_ip_address: IPv4Address) -> bool: @@ -217,3 +267,13 @@ class Terminal(Service): ) self.connected = False return True + + + def send( + self, + payload: SSHPacket, + dest_ip_address: Optional[IPv4Address] = None, + session_id: Optional[str] = None, + ) -> bool: + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) + From 8061102587f36c180203b60e1cb167427e1147c9 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 9 Jul 2024 11:55:16 +0100 Subject: [PATCH 13/95] #2710 - commit before changing branch --- .../system/services/terminal/terminal.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index bf852823..1dd3133d 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -27,7 +27,7 @@ class TerminalClientConnection(BaseModel): connection_id: str """Connection UUID.""" - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. + parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" is_active: bool = True @@ -145,13 +145,12 @@ class Terminal(Service): def validate_user(self, user: Dict[str]) -> bool: return True if user.get("username") in self.user_connections else False - def _ssh_process_logoff(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - + if self.validate_user(user_account): # Account is logged in - self.user_connections.pop[user_account["username"]] # assumption atm + self.user_connections.pop[user_account["username"]] # assumption atm self.is_connected = False return True else: @@ -164,33 +163,36 @@ class Terminal(Service): """Send confirmation of successful disconnect""" transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE - payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message, ssh_output=RequestResponse(status="success")) + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + ssh_output=RequestResponse(status="success"), + ) self.send(payload=payload) def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: - self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message ==SSHConnectionMessage. SSH_MSG_CHANNEL_CLOSE: - result = self._ssh_process_logoff(session_id=session_id) - # We need to close on the other machine as well - self.send_logoff_ack() + self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") + if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + result = self._ssh_process_logoff(session_id=session_id) + # We need to close on the other machine as well + self.send_logoff_ack() - elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - src_ip = kwargs.get("frame").ip.src_ip_address - user_account = payload.get("user_account", {}) - result = self._ssh_process_login(src_ip=src_ip, session_id=session_id, user_account=user_account) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: + src_ip = kwargs.get("frame").ip.src_ip_address + user_account = payload.get("user_account", {}) + result = self._ssh_process_login(src_ip=src_ip, session_id=session_id, user_account=user_account) - elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: - # Ensure we only ever process requests if we have a established connection (e.g session_id is provided and validated) - result = self._ssh_process_command(session_id=session_id) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: + # Ensure we only ever process requests if we have a established connection (e.g session_id is provided and validated) + result = self._ssh_process_command(session_id=session_id) - else: - self.sys_log.warning("Encounter unexpected message type, rejecting connection") - # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED - return False - - self.send(payload=result, session_id=session_id) - return True + else: + self.sys_log.warning("Encounter unexpected message type, rejecting connection") + # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED + return False + self.send(payload=result, session_id=session_id) + return True # %% Outbound @@ -268,12 +270,10 @@ class Terminal(Service): self.connected = False return True - def send( - self, - payload: SSHPacket, - dest_ip_address: Optional[IPv4Address] = None, - session_id: Optional[str] = None, - ) -> bool: - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) - + self, + payload: SSHPacket, + dest_ip_address: Optional[IPv4Address] = None, + session_id: Optional[str] = None, + ) -> bool: + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) From dc3558bc4dbf4c4afb70cb1ae6e0a7b973e6bc89 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 10 Jul 2024 17:39:45 +0100 Subject: [PATCH 14/95] #2710 - End of Day commit --- .../simulator/network/protocols/ssh.py | 2 + .../system/services/terminal/terminal.py | 74 ++++++++++++++----- tests/integration_tests/system/test_nmap.py | 2 +- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7be81982..361c2552 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -56,6 +56,8 @@ class SSHConnectionMessage(IntEnum): SSH_MSG_CHANNEL_CLOSE = 87 """Closes the channel.""" + SSH_LOGOFF_ACK = 89 + """Logoff confirmation acknowledgement""" class SSHPacket(DataPacket): """Represents an SSHPacket.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 1dd3133d..e5ff9054 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -24,8 +24,8 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections within the Terminal class. """ - connection_id: str - """Connection UUID.""" + session_id: str + """Session UUID.""" parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" @@ -76,6 +76,8 @@ class Terminal(Service): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + # %% Util + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -100,6 +102,22 @@ class Terminal(Service): rm = super()._init_request_manager() return rm + def _validate_login(self, user_account: Optional[str]) -> bool: + """Validate login credentials are valid.""" + # Pending login/Usermanager implementation + if user_account: + # validate bits - poke UserManager with provided info + # return self.user_manager.validate(user_account) + pass + else: + pass + # user_account = next(iter(self.user_connections)) + # return self.user_manager.validate(user_account) + + return True + + + # %% Inbound def _generate_connection_id(self) -> str: @@ -142,40 +160,50 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True - def validate_user(self, user: Dict[str]) -> bool: - return True if user.get("username") in self.user_connections else False + def validate_user(self, session_id: str) -> bool: + return True - def _ssh_process_logoff(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: + def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - if self.validate_user(user_account): + if self.validate_user(session_id): # Account is logged in - self.user_connections.pop[user_account["username"]] # assumption atm - self.is_connected = False return True else: self.sys_log.warning("User account credentials invalid.") + return False def _ssh_process_command(self, session_id: str, *args, **kwargs) -> bool: return True - def send_logoff_ack(self): + def send_logoff_ack(self, session_id: str): """Send confirmation of successful disconnect""" transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE + connection_message = SSHConnectionMessage.SSH_LOGOFF_ACK payload: SSHPacket = SSHPacket( transport_message=transport_message, connection_message=connection_message, - ssh_output=RequestResponse(status="success"), + ssh_output=RequestResponse(status="success", data={"reason": "Successfully Disconnected"}), ) - self.send(payload=payload) + self.send(payload=payload, session_id=session_id) def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + # shouldn't be expecting to see anything other than SSHPacket payloads currently + # confirm that we are receiving the + if not isinstance(payload, SSHPacket): + return False self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: - result = self._ssh_process_logoff(session_id=session_id) + + if payload.connection_message == SSHConnectionMessage.SSH_LOGOFF_ACK: + # Logoff acknowledgement received. NFA needed. + self.sys_log.debug("Received confirmation of successful disconnect") + return True + + elif payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + self._ssh_process_logoff(session_id=session_id) + self.sys_log.debug("Disconnect message received, sending logoff ack") # We need to close on the other machine as well - self.send_logoff_ack() + self.send_logoff_ack(session_id=session_id) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: src_ip = kwargs.get("frame").ip.src_ip_address @@ -191,12 +219,13 @@ class Terminal(Service): # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED return False - self.send(payload=result, session_id=session_id) + # self.send(payload=result, session_id=session_id) return True + # %% Outbound - def login(self, dest_ip_address: IPv4Address) -> bool: + def login(self, dest_ip_address: IPv4Address, user_account: dict[str]) -> bool: """ Perform an initial login request. @@ -204,13 +233,14 @@ class Terminal(Service): """ # TODO: This will need elaborating when user accounts are implemented self.sys_log.info("Attempting Login") - return self.ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=self.user_account) + return self._ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=user_account) - def ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: + def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" if not user_account: # Setting default creds (Best to use this until we have more clarification around user accounts) user_account = {self.user_name: "placeholder", self.password: "placeholder"} + # something like self.user_manager.get_user_details ? # Implement SSHPacket class payload: SSHPacket = SSHPacket( @@ -275,5 +305,9 @@ class Terminal(Service): payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None, + user_account: Optional[str] = None, ) -> bool: - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=Port.SSH, session_id=session_id) + """Send a payload out from the Terminal.""" + self._validate_login(user_account) + self.sys_log.debug(f"Sending payload: {payload} from session: {session_id}") + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id) diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index bbfa4f43..a261f272 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -106,7 +106,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): expected_result = { IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], + IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS, Port.SSH], IPProtocol.UDP: [Port.ARP, Port.NTP], }, } From 2eb36149b28a55cdea48e6d8ea63f6e883de9112 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 15 Jul 2024 08:20:11 +0100 Subject: [PATCH 15/95] #2710 - Prep for draft PR --- .../simulator/network/hardware/base.py | 1 - .../simulator/network/protocols/ssh.py | 1 + .../services/database/database_service.py | 2 +- .../system/services/terminal/terminal.py | 189 ++++++++---------- .../_system/_services/test_terminal.py | 112 +++++++++++ 5 files changed, 199 insertions(+), 106 deletions(-) create mode 100644 tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 6942d280..610dd071 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -256,7 +256,6 @@ class NetworkInterface(SimComponent, ABC): """ # Determine the direction of the traffic direction = "inbound" if inbound else "outbound" - # Initialize protocol and port variables protocol = None port = None diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 361c2552..7d1f915e 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -59,6 +59,7 @@ class SSHConnectionMessage(IntEnum): SSH_LOGOFF_ACK = 89 """Logoff confirmation acknowledgement""" + class SSHPacket(DataPacket): """Represents an SSHPacket.""" diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..d6feafbd 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -19,7 +19,7 @@ _LOGGER = getLogger(__name__) class DatabaseService(Service): """ - A class for simulating a generic SQL Server service. +A class for simulating a generic SQL Server service. This class inherits from the `Service` class and provides methods to simulate a SQL database. """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e5ff9054..3324c4e4 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from uuid import uuid4 from pydantic import BaseModel @@ -24,9 +24,6 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections within the Terminal class. """ - session_id: str - """Session UUID.""" - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. """The parent Node that this connection was created on.""" @@ -104,34 +101,44 @@ class Terminal(Service): def _validate_login(self, user_account: Optional[str]) -> bool: """Validate login credentials are valid.""" - # Pending login/Usermanager implementation - if user_account: - # validate bits - poke UserManager with provided info - # return self.user_manager.validate(user_account) - pass + # TODO: Interact with UserManager to check user_account details + if len(self.user_connections) == 0: + # No current connections + self.sys_log.warning("Login Required!") + return False else: - pass - # user_account = next(iter(self.user_connections)) - # return self.user_manager.validate(user_account) - - return True - - + return True # %% Inbound - def _generate_connection_id(self) -> str: + def _generate_connection_uuid(self) -> str: """Generate a unique connection ID.""" return str(uuid4()) - def process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Process User request to login to Terminal.""" - if user_account in self.user_connections: + def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: + """Process User request to login to Terminal. + + :param dest_ip_address: The IP address of the node we want to connect to. + :return: True if successful, False otherwise. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning("Cannot process login as service is not running") + return False + user_account = f"Username: placeholder, Password: placeholder" + if self.connection_uuid in self.user_connections: self.sys_log.debug("User authentication passed") return True else: - self._ssh_process_login(dest_ip_address=dest_ip_address, user_account=user_account) - self.process_login(dest_ip_address=dest_ip_address, user_account=user_account) + # Need to send a login request + # TODO: Refactor with UserManager changes to provide correct credentials and validate. + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN + payload: SSHPacket = SSHPacket(payload="login", + transport_message=transport_message, + connection_message=connection_message) + + self.sys_log.debug(f"Sending login request to {dest_ip_address}") + self.send(payload=payload, dest_ip_address=dest_ip_address) def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" @@ -140,19 +147,20 @@ class Terminal(Service): connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED # Hard coded at current - replace with another method to handle local accounts. - if user_account == f"{self.user_name:} placeholder, {self.password:} placeholder": # hardcoded - connection_id = self._generate_connection_id() - if not self.add_connection(self, connection_id=connection_id): + if user_account == "Username: placeholder, Password: placeholder": # hardcoded + self.connection_uuid = self._generate_connection_uuid() + if not self.add_connection(connection_id=self.connection_uuid): self.sys_log.warning( f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." ) return False else: - self.sys_log.info(f"{self.name}: Connect request for ID: {connection_id} authorised") + self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection(connection_id=connection_id, dest_ip_address=dest_ip_address) - self.user_connections[connection_id] = new_connection + new_connection = TerminalClientConnection(parent_node = self.software_manager.node, + connection_id=self.connection_uuid, dest_ip_address=dest_ip_address) + self.user_connections[self.connection_uuid] = new_connection self.is_connected = True payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) @@ -160,86 +168,51 @@ class Terminal(Service): self.send(payload=payload, dest_ip_address=dest_ip_address) return True - def validate_user(self, session_id: str) -> bool: - return True - def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - - if self.validate_user(session_id): - # Account is logged in - return True - else: - self.sys_log.warning("User account credentials invalid.") - return False - - def _ssh_process_command(self, session_id: str, *args, **kwargs) -> bool: - return True - - def send_logoff_ack(self, session_id: str): - """Send confirmation of successful disconnect""" - transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS - connection_message = SSHConnectionMessage.SSH_LOGOFF_ACK - payload: SSHPacket = SSHPacket( - transport_message=transport_message, - connection_message=connection_message, - ssh_output=RequestResponse(status="success", data={"reason": "Successfully Disconnected"}), - ) - self.send(payload=payload, session_id=session_id) + # TODO: Should remove def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: - # shouldn't be expecting to see anything other than SSHPacket payloads currently - # confirm that we are receiving the + """Receive Payload and process for a response.""" if not isinstance(payload, SSHPacket): return False + + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning(f"Cannot process message as not running") + return False + self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_LOGOFF_ACK: - # Logoff acknowledgement received. NFA needed. - self.sys_log.debug("Received confirmation of successful disconnect") - return True - - elif payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + connection_id = kwargs["connection_id"] + dest_ip_address = kwargs["dest_ip_address"] self._ssh_process_logoff(session_id=session_id) - self.sys_log.debug("Disconnect message received, sending logoff ack") + self.disconnect(dest_ip_address=dest_ip_address) + self.sys_log.debug(f"Disconnecting {connection_id}") # We need to close on the other machine as well - self.send_logoff_ack(session_id=session_id) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - src_ip = kwargs.get("frame").ip.src_ip_address - user_account = payload.get("user_account", {}) - result = self._ssh_process_login(src_ip=src_ip, session_id=session_id, user_account=user_account) + # validate login + user_account = "Username: placeholder, Password: placeholder" + self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) - elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: - # Ensure we only ever process requests if we have a established connection (e.g session_id is provided and validated) - result = self._ssh_process_command(session_id=session_id) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + self.sys_log.debug("Login Successful") + self.is_connected = True + return True else: self.sys_log.warning("Encounter unexpected message type, rejecting connection") - # send a SSH_MSG_CHANNEL_CLOSE if there is a session_id otherwise SSH_MSG_OPEN_FAILED return False - # self.send(payload=result, session_id=session_id) return True - # %% Outbound - - def login(self, dest_ip_address: IPv4Address, user_account: dict[str]) -> bool: - """ - Perform an initial login request. - - If this fails, raises an error. - """ - # TODO: This will need elaborating when user accounts are implemented - self.sys_log.info("Attempting Login") - return self._ssh_remote_login(self, dest_ip_address=dest_ip_address, user_account=user_account) - def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: """Remote login to terminal via SSH.""" if not user_account: - # Setting default creds (Best to use this until we have more clarification around user accounts) - user_account = {self.user_name: "placeholder", self.password: "placeholder"} + # TODO: Generic hardcoded info, will need to be updated with UserManager. + user_account = f"Username: placeholder, Password: placeholder" # something like self.user_manager.get_user_details ? # Implement SSHPacket class @@ -248,7 +221,6 @@ class Terminal(Service): connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, user_account=user_account, ) - # self.send will return bool, payload unchanged? if self.send(payload=payload, dest_ip_address=dest_ip_address): if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") @@ -269,45 +241,54 @@ class Terminal(Service): else: return False - def disconnect(self, connection_id: str): - """Disconnect from remote.""" - self._disconnect(connection_id) + def disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from remote connection. + + :param dest_ip_address: The IP address fo the connection we are terminating. + :return: True if successful, False otherwise. + """ + self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, connection_id: str) -> bool: + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: if not self.is_connected: return False if len(self.user_connections) == 0: self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") return False - if not self.user_connections.get(connection_id): + if not self.user_connections.get(self.connection_uuid): return False software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_id}, - dest_ip_address=self.server_ip_address, - dest_port=self.port, + payload={"type": "disconnect", "connection_id": self.connection_uuid}, + dest_ip_address=dest_ip_address, + dest_port=self.port ) - connection = self.user_connections.pop(connection_id) - self.terminate_connection(connection_id=connection_id) + connection = self.user_connections.pop(self.connection_uuid) connection.is_active = False self.sys_log.info( - f"{self.name}: Disconnected {connection_id} from: {self.user_connections[connection_id]._dest_ip_address}" + f"{self.name}: Disconnected {self.connection_uuid}" ) - self.connected = False return True def send( self, payload: SSHPacket, - dest_ip_address: Optional[IPv4Address] = None, - session_id: Optional[str] = None, - user_account: Optional[str] = None, + dest_ip_address: IPv4Address, ) -> bool: - """Send a payload out from the Terminal.""" - self._validate_login(user_account) - self.sys_log.debug(f"Sending payload: {payload} from session: {session_id}") - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id) + """ + Send a payload out from the Terminal. + + :param payload: The payload to be sent. + :param dest_up_address: The IP address of the payload destination. + """ + if self.operating_state != ServiceOperatingState.RUNNING: + self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") + return False + self.sys_log.debug(f"Sending payload: {payload}") + return super().send( + payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port + ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py new file mode 100644 index 00000000..62933b5c --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -0,0 +1,112 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import Terminal +from primaite.simulator.system.software import SoftwareHealthState + +@pytest.fixture(scope="function") +def terminal_on_computer() -> Tuple[Terminal, Computer]: + computer: Computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer.power_on() + terminal: Terminal = computer.software_manager.software.get("Terminal") + + return [terminal, computer] + +@pytest.fixture(scope="function") +def basic_network() -> Network: + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_a.software_manager.get_open_ports() + + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + return network + + +def test_terminal_creation(terminal_on_computer): + terminal, computer = terminal_on_computer + terminal.describe_state() + +def test_terminal_install_default(): + """Terminal should be auto installed onto Nodes""" + computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer.power_on() + + assert computer.software_manager.software.get("Terminal") + +def test_terminal_not_on_switch(): + """Ensure terminal does not auto-install to switch""" + test_switch = Switch(hostname="Test") + + assert not test_switch.software_manager.software.get("Terminal") + +def test_terminal_send(basic_network): + """Check that Terminal can send """ + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + payload: SSHPacket = SSHPacket(payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN) + + + assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + + +def test_terminal_fail_when_closed(basic_network): + """Ensure Terminal won't attempt to send/receive when off""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + terminal.operating_state = ServiceOperatingState.STOPPED + + assert terminal.login(dest_ip_address="192.168.0.11") is False + + +def test_terminal_disconnect(basic_network): + """Terminal should set is_connected to false on disconnect""" + network: Network = basic_network + computer: Computer = network.get_node_by_hostname("node_a") + terminal: Terminal = computer.software_manager.software.get("Terminal") + + assert terminal.is_connected is False + + terminal.login(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is True + + terminal.disconnect(dest_ip_address="192.168.0.11") + + assert terminal.is_connected is False + +def test_terminal_ignores_when_off(basic_network): + """Terminal should ignore commands when not running""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + computer_b: Computer = network.get_node_by_hostname("node_b") + + terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + + assert terminal_a.is_connected is True + + terminal_a.operating_state = ServiceOperatingState.STOPPED + + payload: SSHPacket = SSHPacket(payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA) + + assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") From 32c2ea0b100e39e6db28b14e1f939852c7ca2c21 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 15 Jul 2024 08:22:18 +0100 Subject: [PATCH 16/95] #2710 - Pre-commit run ahead of raising PR --- .../services/database/database_service.py | 4 +- .../system/services/terminal/terminal.py | 38 +++++++++---------- .../_system/_services/test_terminal.py | 31 ++++++++++----- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index d6feafbd..f061b3c7 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -19,9 +19,9 @@ _LOGGER = getLogger(__name__) class DatabaseService(Service): """ -A class for simulating a generic SQL Server service. + A class for simulating a generic SQL Server service. - This class inherits from the `Service` class and provides methods to simulate a SQL database. + This class inherits from the `Service` class and provides methods to simulate a SQL database. """ password: Optional[str] = None diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 3324c4e4..589492ba 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -117,14 +117,13 @@ class Terminal(Service): def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: """Process User request to login to Terminal. - - :param dest_ip_address: The IP address of the node we want to connect to. + + :param dest_ip_address: The IP address of the node we want to connect to. :return: True if successful, False otherwise. """ if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning("Cannot process login as service is not running") return False - user_account = f"Username: placeholder, Password: placeholder" if self.connection_uuid in self.user_connections: self.sys_log.debug("User authentication passed") return True @@ -133,9 +132,9 @@ class Terminal(Service): # TODO: Refactor with UserManager changes to provide correct credentials and validate. transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN - payload: SSHPacket = SSHPacket(payload="login", - transport_message=transport_message, - connection_message=connection_message) + payload: SSHPacket = SSHPacket( + payload="login", transport_message=transport_message, connection_message=connection_message + ) self.sys_log.debug(f"Sending login request to {dest_ip_address}") self.send(payload=payload, dest_ip_address=dest_ip_address) @@ -158,8 +157,11 @@ class Terminal(Service): self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection(parent_node = self.software_manager.node, - connection_id=self.connection_uuid, dest_ip_address=dest_ip_address) + new_connection = TerminalClientConnection( + parent_node=self.software_manager.node, + connection_id=self.connection_uuid, + dest_ip_address=dest_ip_address, + ) self.user_connections[self.connection_uuid] = new_connection self.is_connected = True @@ -170,7 +172,7 @@ class Terminal(Service): def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - # TODO: Should remove + # TODO: Should remove def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: """Receive Payload and process for a response.""" @@ -178,7 +180,7 @@ class Terminal(Service): return False if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning(f"Cannot process message as not running") + self.sys_log.warning("Cannot process message as not running") return False self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") @@ -212,7 +214,7 @@ class Terminal(Service): """Remote login to terminal via SSH.""" if not user_account: # TODO: Generic hardcoded info, will need to be updated with UserManager. - user_account = f"Username: placeholder, Password: placeholder" + user_account = "Username: placeholder, Password: placeholder" # something like self.user_manager.get_user_details ? # Implement SSHPacket class @@ -242,8 +244,8 @@ class Terminal(Service): return False def disconnect(self, dest_ip_address: IPv4Address) -> bool: - """Disconnect from remote connection. - + """Disconnect from remote connection. + :param dest_ip_address: The IP address fo the connection we are terminating. :return: True if successful, False otherwise. """ @@ -263,15 +265,13 @@ class Terminal(Service): software_manager.send_payload_to_session_manager( payload={"type": "disconnect", "connection_id": self.connection_uuid}, dest_ip_address=dest_ip_address, - dest_port=self.port + dest_port=self.port, ) connection = self.user_connections.pop(self.connection_uuid) connection.is_active = False - self.sys_log.info( - f"{self.name}: Disconnected {self.connection_uuid}" - ) + self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") return True def send( @@ -289,6 +289,4 @@ class Terminal(Service): self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") return False self.sys_log.debug(f"Sending payload: {payload}") - return super().send( - payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port - ) + return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 62933b5c..6b0365ce 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -11,14 +11,18 @@ from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import SoftwareHealthState + @pytest.fixture(scope="function") def terminal_on_computer() -> Tuple[Terminal, Computer]: - computer: Computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + computer: Computer = Computer( + hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0 + ) computer.power_on() terminal: Terminal = computer.software_manager.software.get("Terminal") return [terminal, computer] + @pytest.fixture(scope="function") def basic_network() -> Network: network = Network() @@ -37,6 +41,7 @@ def test_terminal_creation(terminal_on_computer): terminal, computer = terminal_on_computer terminal.describe_state() + def test_terminal_install_default(): """Terminal should be auto installed onto Nodes""" computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) @@ -44,22 +49,25 @@ def test_terminal_install_default(): assert computer.software_manager.software.get("Terminal") + def test_terminal_not_on_switch(): """Ensure terminal does not auto-install to switch""" test_switch = Switch(hostname="Test") assert not test_switch.software_manager.software.get("Terminal") + def test_terminal_send(basic_network): - """Check that Terminal can send """ + """Check that Terminal can send""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") - payload: SSHPacket = SSHPacket(payload="Test_Payload", - transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN) - + payload: SSHPacket = SSHPacket( + payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + ) assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") @@ -91,6 +99,7 @@ def test_terminal_disconnect(basic_network): assert terminal.is_connected is False + def test_terminal_ignores_when_off(basic_network): """Terminal should ignore commands when not running""" network: Network = basic_network @@ -99,14 +108,16 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b assert terminal_a.is_connected is True terminal_a.operating_state = ServiceOperatingState.STOPPED - payload: SSHPacket = SSHPacket(payload="Test_Payload", - transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA) + payload: SSHPacket = SSHPacket( + payload="Test_Payload", + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") From fee7f202a66529564f422ea591bda3708bd04ee5 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 15 Jul 2024 10:06:28 +0100 Subject: [PATCH 17/95] #2711 - Amending some minor changes spotted whilst raising PR --- src/primaite/simulator/network/hardware/base.py | 1 + src/primaite/simulator/network/protocols/ssh.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 610dd071..6942d280 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -256,6 +256,7 @@ class NetworkInterface(SimComponent, ABC): """ # Determine the direction of the traffic direction = "inbound" if inbound else "outbound" + # Initialize protocol and port variables protocol = None port = None diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7d1f915e..7be81982 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -56,9 +56,6 @@ class SSHConnectionMessage(IntEnum): SSH_MSG_CHANNEL_CLOSE = 87 """Closes the channel.""" - SSH_LOGOFF_ACK = 89 - """Logoff confirmation acknowledgement""" - class SSHPacket(DataPacket): """Represents an SSHPacket.""" From 34969c588b6a93134d77ecc518b64bed1988437d Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 16 Jul 2024 08:59:36 +0100 Subject: [PATCH 18/95] #2676: Fix mismerge. --- src/primaite/game/game.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c2a1961b..3e129879 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency -from primaite.simulator.network.hardware.base import NodeOperatingState +from primaite.simulator.network.hardware.base import NodeOperatingState, NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server From 07e736977ccefbe567cab88aed0c023c012c92d6 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 16 Jul 2024 16:58:11 +0100 Subject: [PATCH 19/95] #2676: Fix some more integration tests --- tests/assets/configs/bad_primaite_session.yaml | 2 +- tests/assets/configs/basic_switched_network.yaml | 2 +- tests/assets/configs/eval_only_primaite_session.yaml | 2 +- tests/assets/configs/firewall_actions_network.yaml | 2 +- tests/assets/configs/fix_duration_one_item.yaml | 2 +- tests/assets/configs/software_fix_duration.yaml | 2 +- tests/assets/configs/test_primaite_session.yaml | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 8cbd3ae9..c83cadc8 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -99,7 +99,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 69187fa3..fed0f52d 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -92,7 +92,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index de861dcc..3d60eb6e 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -111,7 +111,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index fd5b1bf8..2292616d 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -68,7 +68,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/fix_duration_one_item.yaml b/tests/assets/configs/fix_duration_one_item.yaml index 59bc15f9..bd0fb61f 100644 --- a/tests/assets/configs/fix_duration_one_item.yaml +++ b/tests/assets/configs/fix_duration_one_item.yaml @@ -89,7 +89,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/software_fix_duration.yaml b/tests/assets/configs/software_fix_duration.yaml index 1acb05a9..1a28258b 100644 --- a/tests/assets/configs/software_fix_duration.yaml +++ b/tests/assets/configs/software_fix_duration.yaml @@ -89,7 +89,7 @@ agents: - NONE tcp: - DNS - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index eb8103e8..27cfa240 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -120,7 +120,7 @@ agents: num_files: 1 num_nics: 2 include_num_access: false - include_nmne: true + include_nmne: false routers: - hostname: router_1 num_ports: 0 From 061509dffdd5d7f7fa4088bce2b082b487567f3b Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 17 Jul 2024 10:43:04 +0100 Subject: [PATCH 20/95] #2676: Further test fixes. --- src/primaite/game/game.py | 2 +- .../scenario_with_placeholders/scenario.yaml | 2 +- .../observations/test_nic_observations.py | 16 +++++++++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 3e129879..aca75b63 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency -from primaite.simulator.network.hardware.base import NodeOperatingState, NetworkInterface +from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server diff --git a/tests/assets/configs/scenario_with_placeholders/scenario.yaml b/tests/assets/configs/scenario_with_placeholders/scenario.yaml index 81848b2d..ef930a1a 100644 --- a/tests/assets/configs/scenario_with_placeholders/scenario.yaml +++ b/tests/assets/configs/scenario_with_placeholders/scenario.yaml @@ -44,7 +44,7 @@ agents: num_files: 1 num_nics: 1 include_num_access: false - include_nmne: true + include_nmne: false - type: LINKS label: LINKS diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 88dd2bd5..dfad8b59 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -9,9 +9,11 @@ from gymnasium import spaces from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -75,6 +77,18 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": [ + "DELETE", + "ENCRYPT", + ], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + NetworkInterface.nmne_config = store_nmne_config(nmne_config) + assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) @@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic ) simulation.pre_timestep(0) # apply timestep to whole sim From 43617340148051973518f72b45dce01cbb6f40cf Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 17 Jul 2024 17:50:55 +0100 Subject: [PATCH 21/95] #2676: Code review changes --- src/primaite/game/game.py | 5 +---- src/primaite/simulator/network/hardware/base.py | 4 ++-- src/primaite/simulator/network/nmne.py | 11 +++++------ 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index aca75b63..0c1b3192 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import NmneData, store_nmne_config +from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -110,9 +110,6 @@ class PrimaiteGame: self._reward_calculation_order: List[str] = [name for name in self.agents] """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards.""" - self.nmne_config: NmneData = None - """ Config data from Number of Malicious Network Events.""" - def step(self): """ Perform one step of the simulation/agent loop. diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index f85d3f2e..50549389 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -19,7 +19,7 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.nmne import NmneData +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application @@ -99,7 +99,7 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." - nmne_config: ClassVar[NmneData] = None + nmne_config: ClassVar[NMNEConfig] = None "A dataclass defining malicious network events to be captured." nmne: Dict = Field(default_factory=lambda: {}) diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index 947f27ac..431ec07d 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -1,15 +1,14 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from dataclasses import dataclass, field +from pydantic import BaseModel from typing import Dict, List -@dataclass -class NmneData: +class NMNEConfig(BaseModel): """Store all the information to perform NMNE operations.""" capture_nmne: bool = True """Indicates whether Malicious Network Events (MNEs) should be captured.""" - nmne_capture_keywords: List[str] = field(default_factory=list) + nmne_capture_keywords: List[str] = [] """List of keywords to identify malicious network events.""" capture_by_direction: bool = True """Captures should be organized by traffic direction (inbound/outbound).""" @@ -23,7 +22,7 @@ class NmneData: """Captures should be filtered and categorised based on specific keywords.""" -def store_nmne_config(nmne_config: Dict) -> NmneData: +def store_nmne_config(nmne_config: Dict) -> NMNEConfig: """ Store configuration for capturing Malicious Network Events (MNEs). @@ -51,4 +50,4 @@ def store_nmne_config(nmne_config: Dict) -> NmneData: if not isinstance(nmne_capture_keywords, list): nmne_capture_keywords = [] # Reset to empty list if the provided value is not a list - return NmneData(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) + return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) From 8702dc706797ad7970ba7a5bed1a7fbff7175c04 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 19 Jul 2024 10:34:32 +0100 Subject: [PATCH 22/95] #2735 - tidies up some oif the api, temporarily integrated login checks to ping for testing, added temp test --- .../simulator/network/hardware/base.py | 357 +++++++++++++++++- .../network/hardware/nodes/host/host_node.py | 5 +- .../simulator/system/core/software_manager.py | 8 +- .../system/services/access/user_manager.py | 185 --------- .../services/access/user_session_manager.py | 97 ----- .../simulator/system/services/service.py | 2 +- .../system/test_local_accounts.py | 37 ++ 7 files changed, 391 insertions(+), 300 deletions(-) create mode 100644 tests/integration_tests/system/test_local_accounts.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 64fad264..9e6784c5 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,7 +6,7 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, ClassVar, Dict, Optional, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field @@ -31,14 +31,13 @@ from primaite.simulator.network.nmne import ( ) from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.processes.process import Process -from primaite.simulator.system.services.access.user_manager import UserManager -from primaite.simulator.system.services.access.user_session_manager import UserSessionManager from primaite.simulator.system.services.service import Service from primaite.simulator.system.software import IOSoftware from primaite.utils.converters import convert_dict_enum_keys_to_enum_values @@ -796,6 +795,330 @@ class Link(SimComponent): self.current_load = 0.0 +class User(SimComponent): + """ + Represents a user in the PrimAITE system. + + :param username: The username of the user + :param password: The password of the user + :param disabled: Boolean flag indicating whether the user is disabled + :param is_admin: Boolean flag indicating whether the user has admin privileges + """ + + username: str + password: str + disabled: bool = False + is_admin: bool = False + + def describe_state(self) -> Dict: + """ + Returns a dictionary representing the current state of the user. + + :return: A dict containing the state of the user + """ + return self.model_dump() + + +class UserManager(Service): + """ + Manages users within the PrimAITE system, handling creation, authentication, and administration. + + :param users: A dictionary of all users by their usernames + :param admins: A dictionary of admin users by their usernames + :param disabled_admins: A dictionary of currently disabled admin users by their usernames + """ + + users: Dict[str, User] = Field(default_factory=dict) + admins: Dict[str, User] = Field(default_factory=dict) + disabled_admins: Dict[str, User] = Field(default_factory=dict) + + def __init__(self, **kwargs): + """ + Initializes a UserManager instanc. + + :param username: The username for the default admin user + :param password: The password for the default admin user + """ + kwargs["name"] = "UserManager" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) + self.start() + + def describe_state(self) -> Dict: + """ + Returns the state of the UserManager along with the number of users and admins. + + :return: A dict containing detailed state information + """ + state = super().describe_state() + state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)}) + return state + + def show(self, markdown: bool = False): + """ + Display the Users. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + table = PrettyTable(["Username", "Admin", "Disabled"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} User Manager" + for user in self.users.values(): + table.add_row([user.username, user.is_admin, user.disabled]) + print(table.get_string(sortby="Username")) + + def _is_last_admin(self, username: str) -> bool: + return username in self.admins and len(self.admins) == 1 + + def add_user( + self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False + ) -> bool: + """ + Adds a new user to the system. + + :param username: The username for the new user + :param password: The password for the new user + :param is_admin: Flag indicating if the new user is an admin + :return: True if user was successfully added, False otherwise + """ + if not bypass_can_perform_action and not self._can_perform_action(): + return False + if username in self.users: + self.sys_log.info(f"{self.name}: Failed to create new user {username} as this user name already exists") + return False + user = User(username=username, password=password, is_admin=is_admin) + self.users[username] = user + if is_admin: + self.admins[username] = user + self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}") + return True + + def authenticate_user(self, username: str, password: str) -> Optional[User]: + """ + Authenticates a user's login attempt. + + :param username: The username of the user trying to log in + :param password: The password provided by the user + :return: The User object if authentication is successful, None otherwise + """ + if not self._can_perform_action(): + return None + user = self.users.get(username) + if user and not user.disabled and user.password == password: + self.sys_log.info(f"{self.name}: User authenticated: {username}") + return user + self.sys_log.info(f"{self.name}: Authentication failed for: {username}") + return None + + def change_user_password(self, username: str, current_password: str, new_password: str) -> bool: + """ + Changes a user's password. + + :param username: The username of the user changing their password + :param current_password: The current password of the user + :param new_password: The new password for the user + :return: True if the password was changed successfully, False otherwise + """ + if not self._can_perform_action(): + return False + user = self.users.get(username) + if user and user.password == current_password: + user.password = new_password + self.sys_log.info(f"{self.name}: Password changed for {username}") + return True + self.sys_log.info(f"{self.name}: Password change failed for {username}") + return False + + def disable_user(self, username: str) -> bool: + """ + Disables a user account, preventing them from logging in. + + :param username: The username of the user to disable + :return: True if the user was disabled successfully, False otherwise + """ + if not self._can_perform_action(): + return False + if username in self.users and not self.users[username].disabled: + if self._is_last_admin(username): + self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin") + return False + self.users[username].disabled = True + self.sys_log.info(f"{self.name}: User disabled: {username}") + if username in self.admins: + self.disabled_admins[username] = self.admins.pop(username) + return True + self.sys_log.info(f"{self.name}: Failed to disable user: {username}") + return False + + def enable_user(self, username: str) -> bool: + """ + Enables a previously disabled user account. + + :param username: The username of the user to enable + :return: True if the user was enabled successfully, False otherwise + """ + if username in self.users and self.users[username].disabled: + self.users[username].disabled = False + self.sys_log.info(f"{self.name}: User enabled: {username}") + if username in self.disabled_admins: + self.admins[username] = self.disabled_admins.pop(username) + return True + self.sys_log.info(f"{self.name}: Failed to enable user: {username}") + return False + + +class UserSession(SimComponent): + user: User + start_step: int + last_active_step: int + end_step: Optional[int] = None + local: bool = True + + @classmethod + def create(cls, user: User, timestep: int) -> UserSession: + return UserSession(user=user, start_step=timestep, last_active_step=timestep) + + def describe_state(self) -> Dict: + return self.model_dump() + + +class RemoteUserSession(UserSession): + remote_ip_address: IPV4Address + local: bool = False + + def describe_state(self) -> Dict: + state = super().describe_state() + state["remote_ip_address"] = str(self.remote_ip_address) + return state + + +class UserSessionManager(Service): + node: Node + local_session: Optional[UserSession] = None + remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict) + historic_sessions: List[UserSession] = Field(default_factory=list) + + local_session_timeout_steps: int = 30 + remote_session_timeout_steps: int = 5 + max_remote_sessions: int = 3 + + current_timestep: int = 0 + + def __init__(self, **kwargs): + """ + Initializes a UserSessionManager instance. + + :param username: The username for the default admin user + :param password: The password for the default admin user + """ + kwargs["name"] = "UserSessionManager" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) + self.start() + + def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False): + """Prints a table of the user sessions on the Node.""" + headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"] + + if not include_session_id: + headers = headers[1:] + + table = PrettyTable(headers) + + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.node.hostname} User Sessions" + + def _add_session_to_table(user_session: UserSession): + session_type = "local" + remote_ip = "" + if isinstance(user_session, RemoteUserSession): + session_type = "remote" + remote_ip = str(user_session.remote_ip_address) + data = [ + user_session.uuid, + user_session.user.username, + session_type, + remote_ip, + user_session.start_step, + user_session.last_active_step, + user_session.end_step if user_session.end_step else "", + ] + if not include_session_id: + data = data[1:] + table.add_row(data) + + if self.local_session is not None: + _add_session_to_table(self.local_session) + + for user_session in self.remote_sessions.values(): + _add_session_to_table(user_session) + + if include_historic: + for user_session in self.historic_sessions: + _add_session_to_table(user_session) + + print(table.get_string(sortby="Step Last Active", reversesort=True)) + + def describe_state(self) -> Dict: + return super().describe_state() + + @property + def _user_manager(self) -> UserManager: + return self.software_manager.software["UserManager"] # noqa + + def pre_timestep(self, timestep: int) -> None: + """Apply any pre-timestep logic that helps make sure we have the correct observations.""" + self.current_timestep = timestep + if self.local_session: + if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: + self._timeout_session(self.local_session) + + def _timeout_session(self, session: UserSession) -> None: + session.end_step = self.current_timestep + session_identity = session.user.username + if session.local: + self.local_session = None + session_type = "Local" + else: + self.remote_sessions.pop(session.uuid) + session_type = "Remote" + session_identity = f"{session_identity} {session.remote_ip_address}" + + self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") + + def login(self, username: str, password: str) -> Optional[str]: + if not self._can_perform_action(): + return None + user = self._user_manager.authenticate_user(username=username, password=password) + if user: + self.logout() + self.local_session = UserSession.create(user=user, timestep=self.current_timestep) + self.sys_log.info(f"{self.name}: User {user.username} logged in") + return self.local_session.uuid + else: + self.sys_log.info(f"{self.name}: Incorrect username or password") + + def logout(self): + if not self._can_perform_action(): + return False + if self.local_session: + session = self.local_session + session.end_step = self.current_timestep + self.historic_sessions.append(session) + self.local_session = None + self.sys_log.info(f"{self.name}: User {session.user.username} logged out") + + @property + def local_user_logged_in(self): + return self.local_session is not None + + class Node(SimComponent): """ A basic Node class that represents a node on the network. @@ -889,16 +1212,24 @@ class Node(SimComponent): super().__init__(**kwargs) self.session_manager.node = self self.session_manager.software_manager = self.software_manager - self.software_manager.install(UserSessionManager) + self.software_manager.install(UserSessionManager, node=self) self.software_manager.install(UserManager) + self.user_manager.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True) + self._install_system_software() - # @property - # def user_manager(self) -> UserManager: - # return self.software_manager.software["UserManager"] # noqa - # - # @property - # def _user_session_manager(self) -> UserSessionManager: - # return self.software_manager.software["UserSessionManager"] # noqa + @property + def user_manager(self) -> UserManager: + return self.software_manager.software["UserManager"] # noqa + + @property + def user_session_manager(self) -> UserSessionManager: + return self.software_manager.software["UserSessionManager"] # noqa + + def login(self, username: str, password: str) -> Optional[str]: + return self.user_session_manager.login(username, password) + + def logout(self): + return self.user_session_manager.logout() def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: """ @@ -1434,10 +1765,14 @@ class Node(SimComponent): :param pings: The number of pings to attempt, default is 4. :return: True if the ping is successful, otherwise False. """ + if not self.user_session_manager.local_user_logged_in: + return False if not isinstance(target_ip_address, IPv4Address): target_ip_address = IPv4Address(target_ip_address) if self.software_manager.icmp: + print("yes") return self.software_manager.icmp.ping(target_ip_address, pings) + print("no icmp") return False @abstractmethod diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 80f80a04..aac57e95 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -11,8 +11,6 @@ from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.nmap import NMAP from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.access.user_manager import UserManager -from primaite.simulator.system.services.access.user_session_manager import UserSessionManager from primaite.simulator.system.services.arp.arp import ARP, ARPPacket from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.icmp.icmp import ICMP @@ -318,10 +316,9 @@ class HostNode(Node): network_interface: Dict[int, NIC] = {} "The NICs on the node by port id." - def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, username: str, password: str, **kwargs): + def __init__(self, ip_address: IPV4Address, subnet_mask: IPV4Address, **kwargs): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) - self.user_manager.add_user(username=username, password=password, is_admin=True, bypass_can_perform_action=True) @property def nmap(self) -> Optional[NMAP]: diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index e2266c2d..c52e60ae 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -104,7 +104,7 @@ class SoftwareManager: return True return False - def install(self, software_class: Type[IOSoftwareClass]): + def install(self, software_class: Type[IOSoftwareClass], **install_kwargs) -> None: """ Install an Application or Service. @@ -116,7 +116,11 @@ class SoftwareManager: self.sys_log.warning(f"Cannot install {software_class} as it is already installed") return software = software_class( - software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server + software_manager=self, + sys_log=self.sys_log, + file_system=self.file_system, + dns_server=self.dns_server, + **install_kwargs, ) if isinstance(software, Application): software.install() diff --git a/src/primaite/simulator/system/services/access/user_manager.py b/src/primaite/simulator/system/services/access/user_manager.py index 09f8950e..be6c00e7 100644 --- a/src/primaite/simulator/system/services/access/user_manager.py +++ b/src/primaite/simulator/system/services/access/user_manager.py @@ -1,186 +1 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, Optional - -from prettytable import MARKDOWN, PrettyTable -from pydantic import Field - -from primaite.simulator.core import SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.service import Service - - -class User(SimComponent): - """ - Represents a user in the PrimAITE system. - - :param username: The username of the user - :param password: The password of the user - :param disabled: Boolean flag indicating whether the user is disabled - :param is_admin: Boolean flag indicating whether the user has admin privileges - """ - - username: str - password: str - disabled: bool = False - is_admin: bool = False - - def describe_state(self) -> Dict: - """ - Returns a dictionary representing the current state of the user. - - :return: A dict containing the state of the user - """ - return self.model_dump() - - -class UserManager(Service): - """ - Manages users within the PrimAITE system, handling creation, authentication, and administration. - - :param users: A dictionary of all users by their usernames - :param admins: A dictionary of admin users by their usernames - :param disabled_admins: A dictionary of currently disabled admin users by their usernames - """ - - users: Dict[str, User] = Field(default_factory=dict) - admins: Dict[str, User] = Field(default_factory=dict) - disabled_admins: Dict[str, User] = Field(default_factory=dict) - - def __init__(self, **kwargs): - """ - Initializes a UserManager instanc. - - :param username: The username for the default admin user - :param password: The password for the default admin user - """ - kwargs["name"] = "UserManager" - kwargs["port"] = Port.NONE - kwargs["protocol"] = IPProtocol.NONE - super().__init__(**kwargs) - self.start() - - def describe_state(self) -> Dict: - """ - Returns the state of the UserManager along with the number of users and admins. - - :return: A dict containing detailed state information - """ - state = super().describe_state() - state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)}) - return state - - def show(self, markdown: bool = False): - """ - Display the Users. - - :param markdown: Whether to display the table in Markdown format or not. Default is `False`. - """ - table = PrettyTable(["Username", "Admin", "Enabled"]) - if markdown: - table.set_style(MARKDOWN) - table.align = "l" - table.title = f"{self.sys_log.hostname} User Manager)" - for user in self.users.values(): - table.add_row([user.username, user.is_admin, user.disabled]) - print(table.get_string(sortby="Username")) - - def _is_last_admin(self, username: str) -> bool: - return username in self.admins and len(self.admins) == 1 - - def add_user( - self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False - ) -> bool: - """ - Adds a new user to the system. - - :param username: The username for the new user - :param password: The password for the new user - :param is_admin: Flag indicating if the new user is an admin - :return: True if user was successfully added, False otherwise - """ - if not bypass_can_perform_action and not self._can_perform_action(): - return False - if username in self.users: - return False - user = User(username=username, password=password, is_admin=is_admin) - self.users[username] = user - if is_admin: - self.admins[username] = user - self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}") - return True - - def authenticate_user(self, username: str, password: str) -> Optional[User]: - """ - Authenticates a user's login attempt. - - :param username: The username of the user trying to log in - :param password: The password provided by the user - :return: The User object if authentication is successful, None otherwise - """ - if not self._can_perform_action(): - return None - user = self.users.get(username) - if user and not user.disabled and user.password == password: - self.sys_log.info(f"{self.name}: User authenticated: {username}") - return user - self.sys_log.info(f"{self.name}: Authentication failed for: {username}") - return None - - def change_user_password(self, username: str, current_password: str, new_password: str) -> bool: - """ - Changes a user's password. - - :param username: The username of the user changing their password - :param current_password: The current password of the user - :param new_password: The new password for the user - :return: True if the password was changed successfully, False otherwise - """ - if not self._can_perform_action(): - return False - user = self.users.get(username) - if user and user.password == current_password: - user.password = new_password - self.sys_log.info(f"{self.name}: Password changed for {username}") - return True - self.sys_log.info(f"{self.name}: Password change failed for {username}") - return False - - def disable_user(self, username: str) -> bool: - """ - Disables a user account, preventing them from logging in. - - :param username: The username of the user to disable - :return: True if the user was disabled successfully, False otherwise - """ - if not self._can_perform_action(): - return False - if username in self.users and not self.users[username].disabled: - if self._is_last_admin(username): - self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin") - return False - self.users[username].disabled = True - self.sys_log.info(f"{self.name}: User disabled: {username}") - if username in self.admins: - self.disabled_admins[username] = self.admins.pop(username) - return True - self.sys_log.info(f"{self.name}: Failed to disable user: {username}") - return False - - def enable_user(self, username: str) -> bool: - """ - Enables a previously disabled user account. - - :param username: The username of the user to enable - :return: True if the user was enabled successfully, False otherwise - """ - if not self._can_perform_action(): - return False - if username in self.users and self.users[username].disabled: - self.users[username].disabled = False - self.sys_log.info(f"{self.name}: User enabled: {username}") - if username in self.disabled_admins: - self.admins[username] = self.disabled_admins.pop(username) - return True - self.sys_log.info(f"{self.name}: Failed to enable user: {username}") - return False diff --git a/src/primaite/simulator/system/services/access/user_session_manager.py b/src/primaite/simulator/system/services/access/user_session_manager.py index 03d2dd93..be6c00e7 100644 --- a/src/primaite/simulator/system/services/access/user_session_manager.py +++ b/src/primaite/simulator/system/services/access/user_session_manager.py @@ -1,98 +1 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from __future__ import annotations - -from datetime import datetime, timedelta -from typing import Dict, List, Optional -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from primaite.simulator.core import SimComponent -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.access.user_manager import User, UserManager -from primaite.simulator.system.services.service import Service -from primaite.utils.validators import IPV4Address - - -class UserSession(SimComponent): - user: User - start_step: int - last_active_step: int - end_step: Optional[int] = None - local: bool = True - - @classmethod - def create(cls, user: User, timestep: int) -> UserSession: - return UserSession(user=user, start_step=timestep, last_active_step=timestep) - def describe_state(self) -> Dict: - return self.model_dump() - - -class RemoteUserSession(UserSession): - remote_ip_address: IPV4Address - local: bool = False - - def describe_state(self) -> Dict: - state = super().describe_state() - state["remote_ip_address"] = str(self.remote_ip_address) - return state - - -class UserSessionManager(BaseModel): - node: - local_session: Optional[UserSession] = None - remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict) - historic_sessions: List[UserSession] = Field(default_factory=list) - - local_session_timeout_steps: int = 30 - remote_session_timeout_steps: int = 5 - max_remote_sessions: int = 3 - - current_timestep: int = 0 - - @property - def _user_manager(self) -> UserManager: - return self.software_manager.software["UserManager"] # noqa - - def pre_timestep(self, timestep: int) -> None: - """Apply any pre-timestep logic that helps make sure we have the correct observations.""" - self.current_timestep = timestep - if self.local_session: - if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep: - self._timeout_session(self.local_session) - - def _timeout_session(self, session: UserSession) -> None: - session.end_step = self.current_timestep - session_identity = session.user.username - if session.local: - self.local_session = None - session_type = "Local" - else: - self.remote_sessions.pop(session.uuid) - session_type = "Remote" - session_identity = f"{session_identity} {session.remote_ip_address}" - - self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") - - def login(self, username: str, password: str) -> Optional[str]: - if not self._can_perform_action(): - return None - user = self._user_manager.authenticate_user(username=username, password=password) - if user: - self.logout() - self.local_session = UserSession.create(user=user, timestep=self.current_timestep) - self.sys_log.info(f"{self.name}: User {user.username} logged in") - return self.local_session.uuid - else: - self.sys_log.info(f"{self.name}: Incorrect username or password") - - def logout(self): - if not self._can_perform_action(): - return False - if self.local_session: - session = self.local_session - session.end_step = self.current_timestep - self.historic_sessions.append(session) - self.local_session = None - self.sys_log.info(f"{self.name}: User {session.user.username} logged out") diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 4227175b..5adea6e7 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -46,7 +46,7 @@ class Service(IOSoftware): restart_countdown: Optional[int] = None "If currently restarting, how many timesteps remain until the restart is finished." - def __init__(self, **kwargs):c + def __init__(self, **kwargs): super().__init__(**kwargs) def _can_perform_action(self) -> bool: diff --git a/tests/integration_tests/system/test_local_accounts.py b/tests/integration_tests/system/test_local_accounts.py new file mode 100644 index 00000000..dbdbf857 --- /dev/null +++ b/tests/integration_tests/system/test_local_accounts.py @@ -0,0 +1,37 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server + + +def test_local_accounts_ping_temp(): + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Connect Computer and Server + network.connect(computer.network_interface[1], server.network_interface[1]) + + assert not computer.ping(server.network_interface[1].ip_address) + + computer.user_session_manager.login(username="admin", password="admin") + + assert computer.ping(server.network_interface[1].ip_address) From 9fb3790c1a731779e368207cc02b1fe1f587b5de Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 19 Jul 2024 11:10:57 +0100 Subject: [PATCH 23/95] #2726: Resolve pydantic validators PR comment --- src/primaite/simulator/network/nmne.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index 431ec07d..c9266fba 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -6,7 +6,7 @@ from typing import Dict, List class NMNEConfig(BaseModel): """Store all the information to perform NMNE operations.""" - capture_nmne: bool = True + capture_nmne: bool = False """Indicates whether Malicious Network Events (MNEs) should be captured.""" nmne_capture_keywords: List[str] = [] """List of keywords to identify malicious network events.""" @@ -42,12 +42,8 @@ def store_nmne_config(nmne_config: Dict) -> NMNEConfig: nmne_capture_keywords: List[str] = [] # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect capture_nmne = nmne_config.get("capture_nmne", False) - if not isinstance(capture_nmne, bool): - capture_nmne = True # Revert to default True if the provided value is not a boolean # Update the NMNE capture keywords, appending new keywords if provided nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) - if not isinstance(nmne_capture_keywords, list): - nmne_capture_keywords = [] # Reset to empty list if the provided value is not a list return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) From 2104a7ec7d437153dd735c9d4fa95fffb87dc54a Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 19 Jul 2024 11:17:54 +0100 Subject: [PATCH 24/95] #2712 - Commit before merging in changes on dev --- .../simulator/network/protocols/ssh.py | 13 ++++++-- .../system/services/terminal/terminal.py | 32 +++++++++++++++++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7be81982..86544813 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -56,6 +56,15 @@ class SSHConnectionMessage(IntEnum): SSH_MSG_CHANNEL_CLOSE = 87 """Closes the channel.""" +class SSHUserCredentials(DataPacket): + """Hold Username and Password in SSH Packets""" + + username: str = None + """Username for login""" + + password: str = None + """Password for login""" + class SSHPacket(DataPacket): """Represents an SSHPacket.""" @@ -64,8 +73,8 @@ class SSHPacket(DataPacket): connection_message: SSHConnectionMessage = None - ssh_command: Optional[str] = None # This is the request string + connection_uuid: Optional[str] = None # The connection uuid used to validate the session ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse - user_account: Optional[Dict] = None # The user account we will use to login if we do not have a current connection. + ssh_command: Optional[str] = None # This is the request string \ No newline at end of file diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 589492ba..3cf9fc0d 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -7,8 +7,8 @@ from uuid import uuid4 from pydantic import BaseModel -from primaite.interface.request import RequestResponse -from primaite.simulator.core import RequestManager +from primaite.interface.request import RequestFormat, RequestResponse +from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol @@ -96,7 +96,11 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" # TODO: Expand with a login validator? + + _login_valid = Terminal._LoginValidator(terminal=self) + rm = super()._init_request_manager() + rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid)) return rm def _validate_login(self, user_account: Optional[str]) -> bool: @@ -109,10 +113,32 @@ class Terminal(Service): else: return True + class _LoginValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only allow them through if the + User is logged into the Terminal. + + Login is required before making use of the Terminal. + """ + + terminal: Terminal + """Save a reference to the Terminal instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the Terminal has valid login credentials""" + return self.terminal.login_status + + @property + def fail_message(self) -> str: + """Message that is reported when a request is rejected by this validator""" + return ("Cannot perform request on terminal as not logged in.") + + # %% Inbound def _generate_connection_uuid(self) -> str: """Generate a unique connection ID.""" + # This might not be needed given user_manager.login() returns a UUID. return str(uuid4()) def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: @@ -136,7 +162,7 @@ class Terminal(Service): payload="login", transport_message=transport_message, connection_message=connection_message ) - self.sys_log.debug(f"Sending login request to {dest_ip_address}") + self.sys_log.info(f"Sending login request to {dest_ip_address}") self.send(payload=payload, dest_ip_address=dest_ip_address) def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: From 155562cb683fa6216eae02598528e2662f8ce0f7 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 19 Jul 2024 11:18:17 +0100 Subject: [PATCH 25/95] #2712 - Commit before merging in changes on dev --- src/primaite/simulator/network/protocols/ssh.py | 3 ++- .../simulator/system/services/terminal/terminal.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 86544813..af1c550a 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -56,6 +56,7 @@ class SSHConnectionMessage(IntEnum): SSH_MSG_CHANNEL_CLOSE = 87 """Closes the channel.""" + class SSHUserCredentials(DataPacket): """Hold Username and Password in SSH Packets""" @@ -77,4 +78,4 @@ class SSHPacket(DataPacket): ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse - ssh_command: Optional[str] = None # This is the request string \ No newline at end of file + ssh_command: Optional[str] = None # This is the request string diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 3cf9fc0d..9dd40edc 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -100,7 +100,12 @@ class Terminal(Service): _login_valid = Terminal._LoginValidator(terminal=self) rm = super()._init_request_manager() - rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid)) + rm.add_request( + "login", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid + ), + ) return rm def _validate_login(self, user_account: Optional[str]) -> bool: @@ -127,12 +132,11 @@ class Terminal(Service): def __call__(self, request: RequestFormat, context: Dict) -> bool: """Return whether the Terminal has valid login credentials""" return self.terminal.login_status - + @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator""" - return ("Cannot perform request on terminal as not logged in.") - + return "Cannot perform request on terminal as not logged in." # %% Inbound From e4ade6ba5484f70d2b9f1e5917513d4d698823eb Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 19 Jul 2024 12:02:43 +0100 Subject: [PATCH 26/95] #2676: Merge nmne.py with io.py --- src/primaite/game/game.py | 2 +- src/primaite/session/io.py | 46 +++++++++++++++++ .../simulator/network/hardware/base.py | 2 +- src/primaite/simulator/network/nmne.py | 49 ------------------- .../observations/test_nic_observations.py | 2 +- .../network/test_capture_nmne.py | 2 +- 6 files changed, 50 insertions(+), 53 deletions(-) delete mode 100644 src/primaite/simulator/network/nmne.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 0c1b3192..cd0180db 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -16,6 +16,7 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort +from primaite.session.io import store_nmne_config from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState @@ -26,7 +27,6 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 78d7cb3c..2d0d5897 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -131,3 +131,49 @@ class PrimaiteIO: new = cls(settings=cls.Settings(**config)) return new + + +class NMNEConfig(BaseModel): + """Store all the information to perform NMNE operations.""" + + capture_nmne: bool = False + """Indicates whether Malicious Network Events (MNEs) should be captured.""" + nmne_capture_keywords: List[str] = [] + """List of keywords to identify malicious network events.""" + capture_by_direction: bool = True + """Captures should be organized by traffic direction (inbound/outbound).""" + capture_by_ip_address: bool = False + """Captures should be organized by source or destination IP address.""" + capture_by_protocol: bool = False + """Captures should be organized by network protocol (e.g., TCP, UDP).""" + capture_by_port: bool = False + """Captures should be organized by source or destination port.""" + capture_by_keyword: bool = False + """Captures should be filtered and categorised based on specific keywords.""" + + +def store_nmne_config(nmne_config: Dict) -> NMNEConfig: + """ + Store configuration for capturing Malicious Network Events (MNEs). + + This function updates global settings related to NMNE capture, including whether to capture + NMNEs and what keywords to use for identifying NMNEs. + + The function ensures that the settings are updated only if they are provided in the + `nmne_config` dictionary, and maintains type integrity by checking the types of the provided + values. + + :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys + include: + "capture_nmne" (bool) to indicate whether NMNEs should be captured; + "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. + :rvar dataclass with data read from config file. + """ + nmne_capture_keywords: List[str] = [] + # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect + capture_nmne = nmne_config.get("capture_nmne", False) + + # Update the NMNE capture keywords, appending new keywords if provided + nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) + + return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 50549389..aafdbe5c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -14,12 +14,12 @@ from pydantic import BaseModel, Field from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse +from primaite.session.io import NMNEConfig from primaite.simulator import SIM_OUTPUT from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py deleted file mode 100644 index c9266fba..00000000 --- a/src/primaite/simulator/network/nmne.py +++ /dev/null @@ -1,49 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from pydantic import BaseModel -from typing import Dict, List - - -class NMNEConfig(BaseModel): - """Store all the information to perform NMNE operations.""" - - capture_nmne: bool = False - """Indicates whether Malicious Network Events (MNEs) should be captured.""" - nmne_capture_keywords: List[str] = [] - """List of keywords to identify malicious network events.""" - capture_by_direction: bool = True - """Captures should be organized by traffic direction (inbound/outbound).""" - capture_by_ip_address: bool = False - """Captures should be organized by source or destination IP address.""" - capture_by_protocol: bool = False - """Captures should be organized by network protocol (e.g., TCP, UDP).""" - capture_by_port: bool = False - """Captures should be organized by source or destination port.""" - capture_by_keyword: bool = False - """Captures should be filtered and categorised based on specific keywords.""" - - -def store_nmne_config(nmne_config: Dict) -> NMNEConfig: - """ - Store configuration for capturing Malicious Network Events (MNEs). - - This function updates global settings related to NMNE capture, including whether to capture - NMNEs and what keywords to use for identifying NMNEs. - - The function ensures that the settings are updated only if they are provided in the - `nmne_config` dictionary, and maintains type integrity by checking the types of the provided - values. - - :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys - include: - "capture_nmne" (bool) to indicate whether NMNEs should be captured; - "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. - :rvar dataclass with data read from config file. - """ - nmne_capture_keywords: List[str] = [] - # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect - capture_nmne = nmne_config.get("capture_nmne", False) - - # Update the NMNE capture keywords, appending new keywords if provided - nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) - - return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index dfad8b59..7f86d26d 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -9,11 +9,11 @@ from gymnasium import spaces from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame +from primaite.session.io import store_nmne_config from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index f6e4c685..b4162e58 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,9 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.io import store_nmne_config from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.nmne import store_nmne_config from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection From 82a11b8b85c28bf90bbac3169f196b09fbfb8c4b Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 19 Jul 2024 12:54:01 +0100 Subject: [PATCH 27/95] #2676: Updated doc strings --- src/primaite/session/io.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 2d0d5897..c634e835 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -156,18 +156,17 @@ def store_nmne_config(nmne_config: Dict) -> NMNEConfig: """ Store configuration for capturing Malicious Network Events (MNEs). - This function updates global settings related to NMNE capture, including whether to capture - NMNEs and what keywords to use for identifying NMNEs. + This function updates settings related to NMNE capture, stored in NMNEConfig including whether + to capture NMNEs and the keywords to use for identifying NMNEs. The function ensures that the settings are updated only if they are provided in the - `nmne_config` dictionary, and maintains type integrity by checking the types of the provided - values. + `nmne_config` dictionary, and maintains type integrity by relying on pydantic validators. :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: "capture_nmne" (bool) to indicate whether NMNEs should be captured; "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. - :rvar dataclass with data read from config file. + :rvar class with data read from config file. """ nmne_capture_keywords: List[str] = [] # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect From 3c590a873340a99f4d47bf6b693d1b9716922d43 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 22 Jul 2024 09:58:09 +0100 Subject: [PATCH 28/95] #2712 - Commit before changing branches --- .../system/services/terminal/terminal.py | 153 +++++------------- .../_system/_services/test_terminal.py | 6 +- 2 files changed, 47 insertions(+), 112 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 9dd40edc..039fbeb3 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -17,6 +17,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState + +# TODO: This might not be needed now? class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -52,9 +54,6 @@ class TerminalClientConnection(BaseModel): class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - user_account: Optional[str] = None - "The User Account used for login" - is_connected: bool = False "Boolean Value for whether connected" @@ -64,8 +63,6 @@ class Terminal(Service): operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING """Initial Operating State""" - user_connections: Dict[str, TerminalClientConnection] = {} - """List of authenticated connected users""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -85,38 +82,24 @@ class Terminal(Service): :rtype: Dict """ state = super().describe_state() - - state.update({"hostname": self.name}) return state def apply_request(self, request: List[str | int | float | Dict], context: Dict | None = None) -> RequestResponse: - """Apply Temrinal Request.""" + """Apply Terminal Request.""" return super().apply_request(request, context) def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" - # TODO: Expand with a login validator? - _login_valid = Terminal._LoginValidator(terminal=self) rm = super()._init_request_manager() - rm.add_request( - "login", - request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid - ), - ) + rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid)) return rm - def _validate_login(self, user_account: Optional[str]) -> bool: + def _validate_login(self, connection_id: str) -> bool: """Validate login credentials are valid.""" - # TODO: Interact with UserManager to check user_account details - if len(self.user_connections) == 0: - # No current connections - self.sys_log.warning("Login Required!") - return False - else: - return True + return self.parent.UserSessionManager.validate_remote_session_uuid(connection_id) + class _LoginValidator(RequestPermissionValidator): """ @@ -132,77 +115,64 @@ class Terminal(Service): def __call__(self, request: RequestFormat, context: Dict) -> bool: """Return whether the Terminal has valid login credentials""" return self.terminal.login_status - + @property def fail_message(self) -> str: """Message that is reported when a request is rejected by this validator""" - return "Cannot perform request on terminal as not logged in." + return ("Cannot perform request on terminal as not logged in.") + # %% Inbound - def _generate_connection_uuid(self) -> str: - """Generate a unique connection ID.""" - # This might not be needed given user_manager.login() returns a UUID. - return str(uuid4()) - - def login(self, dest_ip_address: IPv4Address, **kwargs) -> bool: + def login(self, username: str, password: str, ip_address: Optional[IPv4Address]=None) -> bool: """Process User request to login to Terminal. :param dest_ip_address: The IP address of the node we want to connect to. + :param username: The username credential. + :param password: The user password component of credentials. :return: True if successful, False otherwise. """ if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning("Cannot process login as service is not running") return False - if self.connection_uuid in self.user_connections: - self.sys_log.debug("User authentication passed") + + # need to determine if this is a local or remote login + + if ip_address: + # ip_address has been given for remote login + return self._send_remote_login(username=username, password=password, ip_address=ip_address) + + return self._process_local_login(username=username, password=password) + + + def _process_local_login(self, username: str, password: str) -> bool: + """Local session login to terminal.""" + self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + if self.connection_uuid: + self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") return True else: - # Need to send a login request - # TODO: Refactor with UserManager changes to provide correct credentials and validate. - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN - payload: SSHPacket = SSHPacket( - payload="login", transport_message=transport_message, connection_message=connection_message - ) + self.sys_log.warning("Login failed, incorrect Username or Password") + return False - self.sys_log.info(f"Sending login request to {dest_ip_address}") - self.send(payload=payload, dest_ip_address=dest_ip_address) + def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: + """Attempt to login to a remote terminal.""" + pass - def _ssh_process_login(self, dest_ip_address: IPv4Address, user_account: dict, **kwargs) -> bool: - """Processes the login attempt. Returns a bool which either rejects the login or accepts it.""" - # we assume that the login fails unless we meet all the criteria. - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_FAILURE - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_FAILED - # Hard coded at current - replace with another method to handle local accounts. - if user_account == "Username: placeholder, Password: placeholder": # hardcoded - self.connection_uuid = self._generate_connection_uuid() - if not self.add_connection(connection_id=self.connection_uuid): - self.sys_log.warning( - f"{self.name}: Connect request for {dest_ip_address} declined. Service is at capacity." - ) - return False - else: - self.sys_log.info(f"{self.name}: Connect request for ID: {self.connection_uuid} authorised") - transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS - connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN_CONFIRMATION - new_connection = TerminalClientConnection( - parent_node=self.software_manager.node, - connection_id=self.connection_uuid, - dest_ip_address=dest_ip_address, - ) - self.user_connections[self.connection_uuid] = new_connection - self.is_connected = True - payload: SSHPacket = SSHPacket(transport_message=transport_message, connection_message=connection_message) + def _process_remote_login(self, username: str, password: str, ip_address:IPv4Address) -> bool: + """Processes a remote terminal requesting to login to this terminal.""" + self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + if self.connection_uuid: + # Send uuid to remote + self.sys_log.info(f"Remote login authorised, connection ID {self.connection_uuid} for {username} on {ip_address}") + # send back to origin. + return True + else: + self.sys_log.warning("Login failed, incorrect Username or Password") + return False - self.send(payload=payload, dest_ip_address=dest_ip_address) - return True - - def _ssh_process_logoff(self, session_id: str, *args, **kwargs) -> bool: - """Process the logoff attempt. Return a bool if succesful or unsuccessful.""" - # TODO: Should remove def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: """Receive Payload and process for a response.""" @@ -213,12 +183,9 @@ class Terminal(Service): self.sys_log.warning("Cannot process message as not running") return False - self.sys_log.debug(f"Received payload: {payload} from session: {session_id}") - if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: connection_id = kwargs["connection_id"] dest_ip_address = kwargs["dest_ip_address"] - self._ssh_process_logoff(session_id=session_id) self.disconnect(dest_ip_address=dest_ip_address) self.sys_log.debug(f"Disconnecting {connection_id}") # We need to close on the other machine as well @@ -240,38 +207,6 @@ class Terminal(Service): return True # %% Outbound - def _ssh_remote_login(self, dest_ip_address: IPv4Address, user_account: Optional[dict] = None) -> bool: - """Remote login to terminal via SSH.""" - if not user_account: - # TODO: Generic hardcoded info, will need to be updated with UserManager. - user_account = "Username: placeholder, Password: placeholder" - # something like self.user_manager.get_user_details ? - - # Implement SSHPacket class - payload: SSHPacket = SSHPacket( - transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - user_account=user_account, - ) - if self.send(payload=payload, dest_ip_address=dest_ip_address): - if payload.connection_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info(f"{self.name} established an ssh connection with {dest_ip_address}") - # Need to confirm if self.uuid is correct. - self.add_connection(self, connection_id=self.uuid, session_id=self.session_id) - return True - else: - self.sys_log.error("Login Failed. Incorrect credentials provided.") - return False - else: - self.sys_log.error("Login Failed. Incorrect credentials provided.") - return False - - def check_connection(self, connection_id: str) -> bool: - """Check whether the connection is valid.""" - if self.is_connected: - return self.send(dest_ip_address=self.dest_ip_address, connection_id=connection_id) - else: - return False def disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from remote connection. diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 6b0365ce..673b11a3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -80,7 +80,7 @@ def test_terminal_fail_when_closed(basic_network): terminal.operating_state = ServiceOperatingState.STOPPED - assert terminal.login(dest_ip_address="192.168.0.11") is False + assert terminal.login(ip_address="192.168.0.11") is False def test_terminal_disconnect(basic_network): @@ -91,7 +91,7 @@ def test_terminal_disconnect(basic_network): assert terminal.is_connected is False - terminal.login(dest_ip_address="192.168.0.11") + terminal.login(ip_address="192.168.0.11") assert terminal.is_connected is True @@ -108,7 +108,7 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_a.login(dest_ip_address="192.168.0.11") # login to computer_b + terminal_a.login(ip_address="192.168.0.11") # login to computer_b assert terminal_a.is_connected is True From a7f9e4502edd85a905901d30ef7b20f0d114f33f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 23 Jul 2024 15:18:20 +0100 Subject: [PATCH 29/95] #2712 - Updates to the login logic and fixing resultant test failures. Updates to terminal.rst and ssh.py --- .../system/services/terminal.rst | 26 +-- .../simulator/network/protocols/ssh.py | 24 ++- .../system/services/terminal/terminal.py | 148 +++++++++++------- .../_system/_services/test_terminal.py | 27 ++-- 4 files changed, 146 insertions(+), 79 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index afa79c0a..49dc941b 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -5,9 +5,16 @@ .. _Terminal: Terminal -######## +======== -The ``Terminal`` provides a generic terminal simulation, by extending the base Service class +The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment. + + +Overview +-------- + +The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically +installed on Nodes when they are instantiated. Key capabilities ================ @@ -17,21 +24,22 @@ Key capabilities - Simulates common Terminal commands - Leverages the Service base class for install/uninstall, status tracking etc. - Usage ===== - - Install on a node via the ``SoftwareManager`` to start the Terminal - - Terminal Clients connect, execute commands and disconnect. + - Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal. + - Terminal Clients connect, execute commands and disconnect from remote components. + - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. Implementation ============== -- Manages SSH commands -- Ensures User login before sending commands -- Processes SSH commands -- Returns results in a ** format. +The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager` +to provide User Credential authentication when receiving/processing commands. + +Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating +the passing of requests to both. Python diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index af1c550a..5eb181a6 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -1,7 +1,8 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import IntEnum -from typing import Dict, Optional +from ipaddress import IPv4Address +from typing import Optional from primaite.interface.request import RequestResponse from primaite.simulator.network.protocols.packet import DataPacket @@ -58,21 +59,32 @@ class SSHConnectionMessage(IntEnum): class SSHUserCredentials(DataPacket): - """Hold Username and Password in SSH Packets""" + """Hold Username and Password in SSH Packets.""" - username: str = None + username: str """Username for login""" - password: str = None + password: str """Password for login""" class SSHPacket(DataPacket): """Represents an SSHPacket.""" - transport_message: SSHTransportMessage = None + sender_ip_address: IPv4Address + """Sender IP Address""" - connection_message: SSHConnectionMessage = None + target_ip_address: IPv4Address + """Target IP Address""" + + transport_message: SSHTransportMessage + """Message Transport Type""" + + connection_message: SSHConnectionMessage + """Message Connection Status""" + + user_account: Optional[SSHUserCredentials] = None + """User Account Credentials if passed""" connection_uuid: Optional[str] = None # The connection uuid used to validate the session diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 039fbeb3..7f37bc28 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -3,30 +3,33 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Dict, List, Optional -from uuid import uuid4 from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.network.protocols.ssh import ( + SSHConnectionMessage, + SSHPacket, + SSHTransportMessage, + SSHUserCredentials, +) from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState - # TODO: This might not be needed now? class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. - This class is used to record current User Connections within the Terminal class. + This class is used to record current remote User Connections to the Terminal class. """ - parent_node: Node # Technically I think this should be HostNode, but that causes a circular import. + parent_node: Node # Technically should be HostNode but this causes circular import error. """The parent Node that this connection was created on.""" is_active: bool = True @@ -35,6 +38,9 @@ class TerminalClientConnection(BaseModel): _dest_ip_address: IPv4Address """Destination IP address of connection""" + _connection_uuid: str = None + """Connection UUID""" + @property def dest_ip_address(self) -> Optional[IPv4Address]: """Destination IP Address.""" @@ -48,7 +54,7 @@ class TerminalClientConnection(BaseModel): def disconnect(self): """Disconnect the connection.""" if self.client and self.is_active: - self.client._disconnect(self.connection_id) # noqa + self.client._disconnect(self._connection_uuid) # noqa class Terminal(Service): @@ -63,6 +69,10 @@ class Terminal(Service): operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING """Initial Operating State""" + remote_connection: TerminalClientConnection = None + + parent: Node + """Parent component the terminal service is installed on.""" def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -93,18 +103,21 @@ class Terminal(Service): _login_valid = Terminal._LoginValidator(terminal=self) rm = super()._init_request_manager() - rm.add_request("login", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid)) + rm.add_request( + "login", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid + ), + ) return rm - def _validate_login(self, connection_id: str) -> bool: + def _validate_login(self) -> bool: """Validate login credentials are valid.""" - return self.parent.UserSessionManager.validate_remote_session_uuid(connection_id) - + return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) class _LoginValidator(RequestPermissionValidator): """ - When requests come in, this validator will only allow them through if the - User is logged into the Terminal. + When requests come in, this validator will only allow them through if the User is logged into the Terminal. Login is required before making use of the Terminal. """ @@ -113,18 +126,17 @@ class Terminal(Service): """Save a reference to the Terminal instance.""" def __call__(self, request: RequestFormat, context: Dict) -> bool: - """Return whether the Terminal has valid login credentials""" - return self.terminal.login_status - + """Return whether the Terminal has valid login credentials.""" + return self.terminal.is_connected + @property def fail_message(self) -> str: - """Message that is reported when a request is rejected by this validator""" - return ("Cannot perform request on terminal as not logged in.") - + """Message that is reported when a request is rejected by this validator.""" + return "Cannot perform request on terminal as not logged in." # %% Inbound - def login(self, username: str, password: str, ip_address: Optional[IPv4Address]=None) -> bool: + def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. :param dest_ip_address: The IP address of the node we want to connect to. @@ -136,15 +148,12 @@ class Terminal(Service): self.sys_log.warning("Cannot process login as service is not running") return False - # need to determine if this is a local or remote login - if ip_address: - # ip_address has been given for remote login + # if ip_address has been provided, we assume we are logging in to a remote terminal. return self._send_remote_login(username=username, password=password, ip_address=ip_address) return self._process_local_login(username=username, password=password) - def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) @@ -157,25 +166,54 @@ class Terminal(Service): def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: """Attempt to login to a remote terminal.""" - pass + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password) + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + user_account=user_account, + target_ip_address=ip_address, + sender_ip_address=self.parent.network_interface[1].ip_address, + ) + self.sys_log.info(f"Sending remote login request to {ip_address}") + return self.send(payload=payload, dest_ip_address=ip_address) - def _process_remote_login(self, username: str, password: str, ip_address:IPv4Address) -> bool: + def _process_remote_login(self, payload: SSHPacket) -> bool: """Processes a remote terminal requesting to login to this terminal.""" + username: str = payload.user_account.username + password: str = payload.user_account.password self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") + if self.connection_uuid: # Send uuid to remote - self.sys_log.info(f"Remote login authorised, connection ID {self.connection_uuid} for {username} on {ip_address}") - # send back to origin. + self.sys_log.info( + f"Remote login authorised, connection ID {self.connection_uuid} for " + f"{username} on {payload.sender_ip_address}" + ) + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + payload = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + connection_uuid=self.connection_uuid, + sender_ip_address=self.parent.network_interface[1].ip_address, + target_ip_address=payload.sender_ip_address, + ) + self.send(payload=payload, dest_ip_address=payload.target_ip_address) return True else: + # UserSessionManager has returned None self.sys_log.warning("Login failed, incorrect Username or Password") return False - - def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: + def receive(self, payload: SSHPacket, **kwargs) -> bool: """Receive Payload and process for a response.""" + self.sys_log.debug(f"Received payload: {payload}") + if not isinstance(payload, SSHPacket): return False @@ -184,6 +222,7 @@ class Terminal(Service): return False if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: + # Close the channel connection_id = kwargs["connection_id"] dest_ip_address = kwargs["dest_ip_address"] self.disconnect(dest_ip_address=dest_ip_address) @@ -191,12 +230,13 @@ class Terminal(Service): # We need to close on the other machine as well elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - # validate login - user_account = "Username: placeholder, Password: placeholder" - self._ssh_process_login(dest_ip_address="192.168.0.10", user_account=user_account) + """Login Request Received.""" + self._process_remote_login(payload=payload) + self.sys_log.info("User Auth Success!") elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.debug("Login Successful") + self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") + self.connection_uuid = payload.connection_uuid self.is_connected = True return True @@ -208,6 +248,26 @@ class Terminal(Service): # %% Outbound + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: + """Disconnect from the remote.""" + if not self.is_connected: + self.sys_log.warning("Not currently connected to remote") + return False + + if not self.remote_connection: + self.sys_log.warning("No remote connection present") + return False + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid}, + dest_ip_address=dest_ip_address, + dest_port=self.port, + ) + self.connection_uuid = None + self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") + return True + def disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from remote connection. @@ -217,28 +277,6 @@ class Terminal(Service): self._disconnect(dest_ip_address=dest_ip_address) self.is_connected = False - def _disconnect(self, dest_ip_address: IPv4Address) -> bool: - if not self.is_connected: - return False - - if len(self.user_connections) == 0: - self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.") - return False - if not self.user_connections.get(self.connection_uuid): - return False - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": self.connection_uuid}, - dest_ip_address=dest_ip_address, - dest_port=self.port, - ) - connection = self.user_connections.pop(self.connection_uuid) - - connection.is_active = False - - self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") - return True - def send( self, payload: SSHPacket, diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 673b11a3..65346b45 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -62,14 +62,17 @@ def test_terminal_send(basic_network): network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") payload: SSHPacket = SSHPacket( payload="Test_Payload", transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address=computer_b.network_interface[1].ip_address, ) - assert terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) def test_terminal_fail_when_closed(basic_network): @@ -77,27 +80,33 @@ def test_terminal_fail_when_closed(basic_network): network: Network = basic_network computer: Computer = network.get_node_by_hostname("node_a") terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") terminal.operating_state = ServiceOperatingState.STOPPED - assert terminal.login(ip_address="192.168.0.11") is False + assert ( + terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) + is False + ) def test_terminal_disconnect(basic_network): """Terminal should set is_connected to false on disconnect""" network: Network = basic_network - computer: Computer = network.get_node_by_hostname("node_a") - terminal: Terminal = computer.software_manager.software.get("Terminal") + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - assert terminal.is_connected is False + assert terminal_a.is_connected is False - terminal.login(ip_address="192.168.0.11") + terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is True + assert terminal_a.is_connected is True - terminal.disconnect(dest_ip_address="192.168.0.11") + terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address) - assert terminal.is_connected is False + assert terminal_a.is_connected is False def test_terminal_ignores_when_off(basic_network): From a36e34ee1d84175e5efb6ad1461797dc169beda4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 24 Jul 2024 08:31:24 +0100 Subject: [PATCH 30/95] #2712 - Prepping ahead of raising PR. --- .../simulation_components/system/services/terminal.rst | 2 +- .../simulator/system/services/terminal/terminal.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 49dc941b..4d1285d1 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -21,7 +21,7 @@ Key capabilities - Authenticates User connection by maintaining an active User account. - Ensures packets are matched to an existing session - - Simulates common Terminal commands + - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. Usage diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 7f37bc28..d3ff47da 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -71,9 +71,6 @@ class Terminal(Service): remote_connection: TerminalClientConnection = None - parent: Node - """Parent component the terminal service is installed on.""" - def __init__(self, **kwargs): kwargs["name"] = "Terminal" kwargs["port"] = Port.SSH @@ -196,14 +193,14 @@ class Terminal(Service): ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA - payload = SSHPacket( + return_payload = SSHPacket( transport_message=transport_message, connection_message=connection_message, connection_uuid=self.connection_uuid, sender_ip_address=self.parent.network_interface[1].ip_address, target_ip_address=payload.sender_ip_address, ) - self.send(payload=payload, dest_ip_address=payload.target_ip_address) + self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) return True else: # UserSessionManager has returned None @@ -232,7 +229,6 @@ class Terminal(Service): elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: """Login Request Received.""" self._process_remote_login(payload=payload) - self.sys_log.info("User Auth Success!") elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") From 1cb6ce02e001a4da1bac128b0f9fe282cba45402 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 24 Jul 2024 10:38:12 +0100 Subject: [PATCH 31/95] #2712 - Correcting the use of TerminalClientConnection for remote connections. Terminal should hold a list of active remote connections to itself with connection uuid for validation --- .../system/services/terminal/terminal.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index d3ff47da..9a71b63a 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -3,6 +3,7 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Dict, List, Optional +from uuid import uuid4 from pydantic import BaseModel @@ -21,7 +22,6 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState -# TODO: This might not be needed now? class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -69,7 +69,7 @@ class Terminal(Service): operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING """Initial Operating State""" - remote_connection: TerminalClientConnection = None + remote_connection: Dict[str, TerminalClientConnection] = {} def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -110,7 +110,8 @@ class Terminal(Service): def _validate_login(self) -> bool: """Validate login credentials are valid.""" - return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) + # return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) + return True class _LoginValidator(RequestPermissionValidator): """ @@ -153,7 +154,8 @@ class Terminal(Service): def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" - self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + self.connection_uuid = str(uuid4()) if self.connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") return True @@ -182,10 +184,11 @@ class Terminal(Service): """Processes a remote terminal requesting to login to this terminal.""" username: str = payload.user_account.username password: str = payload.user_account.password - self.connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") + # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + connection_uuid = str(uuid4()) - if self.connection_uuid: + if connection_uuid: # Send uuid to remote self.sys_log.info( f"Remote login authorised, connection ID {self.connection_uuid} for " @@ -196,11 +199,18 @@ class Terminal(Service): return_payload = SSHPacket( transport_message=transport_message, connection_message=connection_message, - connection_uuid=self.connection_uuid, + connection_uuid=connection_uuid, sender_ip_address=self.parent.network_interface[1].ip_address, target_ip_address=payload.sender_ip_address, ) self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) + + self.remote_connection[connection_uuid] = TerminalClientConnection( + parent_node=self.software_manager.node, + _dest_ip_address=payload.sender_ip_address, + connection_uuid=connection_uuid, + ) + return True else: # UserSessionManager has returned None From a0e675a09a26116a335611e7f204d9ea93df88c6 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 24 Jul 2024 11:20:01 +0100 Subject: [PATCH 32/95] #2712 - Minor changes to login Validator --- src/primaite/simulator/system/services/terminal/terminal.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 9a71b63a..f01b44a2 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -101,9 +101,9 @@ class Terminal(Service): rm = super()._init_request_manager() rm.add_request( - "login", + "send", request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool(self._validate_login()), validator=_login_valid + func=lambda request, context: RequestResponse.from_bool(self.send()), validator=_login_valid ), ) return rm @@ -124,7 +124,7 @@ class Terminal(Service): """Save a reference to the Terminal instance.""" def __call__(self, request: RequestFormat, context: Dict) -> bool: - """Return whether the Terminal has valid login credentials.""" + """Return whether the Terminal is connected.""" return self.terminal.is_connected @property From d0c8aeae301baa4d5f56506181a42955ff77b94d Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 24 Jul 2024 17:08:18 +0100 Subject: [PATCH 33/95] #2735 - implemented remote logins. Added action remote sessions to UserSessionManager describe_state. Added suite of tests for UserSessionManager logins --- .../simulator/network/hardware/base.py | 95 +++++-- .../system/test_local_accounts.py | 37 --- .../test_user_session_manager_logins.py | 250 ++++++++++++++++++ 3 files changed, 325 insertions(+), 57 deletions(-) delete mode 100644 tests/integration_tests/system/test_local_accounts.py create mode 100644 tests/integration_tests/system/test_user_session_manager_logins.py diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 9e6784c5..3ffc7b35 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validate_call import primaite.simulator.network.nmne from primaite import getLogger @@ -989,6 +989,12 @@ class RemoteUserSession(UserSession): remote_ip_address: IPV4Address local: bool = False + @classmethod + def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa + return RemoteUserSession( + user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address + ) + def describe_state(self) -> Dict: state = super().describe_state() state["remote_ip_address"] = str(self.remote_ip_address) @@ -1066,7 +1072,9 @@ class UserSessionManager(Service): print(table.get_string(sortby="Step Last Active", reversesort=True)) def describe_state(self) -> Dict: - return super().describe_state() + state = super().describe_state() + state["active_remote_logins"] = len(self.remote_sessions) + return state @property def _user_manager(self) -> UserManager: @@ -1092,27 +1100,78 @@ class UserSessionManager(Service): self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity") - def login(self, username: str, password: str) -> Optional[str]: + @property + def remote_session_limit_reached(self) -> bool: + return len(self.remote_sessions) >= self.max_remote_sessions + + def validate_remote_session_uuid(self, remote_session_id: str) -> bool: + return remote_session_id in self.remote_sessions + + def _login( + self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None + ) -> Optional[str]: if not self._can_perform_action(): return None - user = self._user_manager.authenticate_user(username=username, password=password) - if user: - self.logout() - self.local_session = UserSession.create(user=user, timestep=self.current_timestep) - self.sys_log.info(f"{self.name}: User {user.username} logged in") - return self.local_session.uuid - else: - self.sys_log.info(f"{self.name}: Incorrect username or password") - def logout(self): + user = self._user_manager.authenticate_user(username=username, password=password) + + if not user: + self.sys_log.info(f"{self.name}: Incorrect username or password") + return None + + session_id = None + if local: + create_new_session = True + if self.local_session: + if self.local_session.user != user: + # logout the current user + self.local_logout() + else: + # not required as existing logged-in user attempting to re-login + create_new_session = False + + if create_new_session: + self.local_session = UserSession.create(user=user, timestep=self.current_timestep) + + session_id = self.local_session.uuid + else: + if not self.remote_session_limit_reached: + remote_session = RemoteUserSession.create( + user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address + ) + session_id = remote_session.uuid + self.remote_sessions[session_id] = remote_session + self.sys_log.info(f"{self.name}: User {user.username} logged in") + return session_id + + def local_login(self, username: str, password: str) -> Optional[str]: + return self._login(username=username, password=password, local=True) + + @validate_call() + def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]: + return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address) + + def _logout(self, local: bool = True, remote_session_id: Optional[str] = None): if not self._can_perform_action(): return False - if self.local_session: + session = None + if local and self.local_session: session = self.local_session session.end_step = self.current_timestep - self.historic_sessions.append(session) self.local_session = None + + if not local and remote_session_id: + session = self.remote_sessions.pop(remote_session_id) + if session: + self.historic_sessions.append(session) self.sys_log.info(f"{self.name}: User {session.user.username} logged out") + return + + def local_logout(self): + self._logout(local=True) + + def remote_logout(self, remote_session_id: str): + self._logout(local=False, remote_session_id=remote_session_id) @property def local_user_logged_in(self): @@ -1225,8 +1284,8 @@ class Node(SimComponent): def user_session_manager(self) -> UserSessionManager: return self.software_manager.software["UserSessionManager"] # noqa - def login(self, username: str, password: str) -> Optional[str]: - return self.user_session_manager.login(username, password) + def local_login(self, username: str, password: str) -> Optional[str]: + return self.user_session_manager.local_login(username, password) def logout(self): return self.user_session_manager.logout() @@ -1765,14 +1824,10 @@ class Node(SimComponent): :param pings: The number of pings to attempt, default is 4. :return: True if the ping is successful, otherwise False. """ - if not self.user_session_manager.local_user_logged_in: - return False if not isinstance(target_ip_address, IPv4Address): target_ip_address = IPv4Address(target_ip_address) if self.software_manager.icmp: - print("yes") return self.software_manager.icmp.ping(target_ip_address, pings) - print("no icmp") return False @abstractmethod diff --git a/tests/integration_tests/system/test_local_accounts.py b/tests/integration_tests/system/test_local_accounts.py deleted file mode 100644 index dbdbf857..00000000 --- a/tests/integration_tests/system/test_local_accounts.py +++ /dev/null @@ -1,37 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.container import Network -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server - - -def test_local_accounts_ping_temp(): - network = Network() - - # Create Computer - computer = Computer( - hostname="computer", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - computer.power_on() - - # Create Server - server = Server( - hostname="server", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) - server.power_on() - - # Connect Computer and Server - network.connect(computer.network_interface[1], server.network_interface[1]) - - assert not computer.ping(server.network_interface[1].ip_address) - - computer.user_session_manager.login(username="admin", password="admin") - - assert computer.ping(server.network_interface[1].ip_address) diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py new file mode 100644 index 00000000..955408ad --- /dev/null +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -0,0 +1,250 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple +from uuid import uuid4 + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server + + +@pytest.fixture(scope="function") +def client_server_network() -> Tuple[Computer, Server, Network]: + network = Network() + + client = Computer( + hostname="client", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client.power_on() + + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + network.connect(client.network_interface[1], server.network_interface[1]) + + return client, server, network + + +def test_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + +def test_local_login_failure(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert not client.user_session_manager.local_user_logged_in + + +def test_new_user_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_manager.add_user(username="jane.doe", password="12345") + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + + +def test_new_local_login_clears_previous_login(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + client.user_manager.add_user(username="jane.doe", password="12345") + + new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "jane.doe" + + assert new_session_id != current_session_id + + +def test_new_local_login_attempt_same_uses_persists(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + new_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + assert new_session_id == current_session_id + + +def test_remote_login_success(client_server_network): + # partial test for now until we get the terminal application in so that amn actual remote connection can be made + client, server, network = client_server_network + + assert not server.user_session_manager.remote_sessions + + remote_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + assert server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + server.user_session_manager.remote_logout(remote_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_remote_login_failure(client_server_network): + # partial test for now until we get the terminal application in so that amn actual remote connection can be made + client, server, network = client_server_network + + assert not server.user_session_manager.remote_sessions + + remote_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_new_user_remote_login_success(client_server_network): + client, server, network = client_server_network + + server.user_manager.add_user(username="jane.doe", password="12345") + + remote_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + assert server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + server.user_session_manager.remote_logout(remote_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id) + + +def test_max_remote_sessions_same_user(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [ + server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10") + for _ in range(server.user_session_manager.max_remote_sessions) + ] + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_max_remote_sessions_different_users(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [] + + for i in range(server.user_session_manager.max_remote_sessions): + username = str(uuid4()) + password = "12345" + server.user_manager.add_user(username=username, password=password) + + remote_session_ids.append( + server.user_session_manager.remote_login( + username=username, password=password, remote_ip_address="192.168.1.10" + ) + ) + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_max_remote_sessions_limit_reached(client_server_network): + client, server, network = client_server_network + + remote_session_ids = [ + server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10") + for _ in range(server.user_session_manager.max_remote_sessions) + ] + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + assert len(server.user_session_manager.remote_sessions) == server.user_session_manager.max_remote_sessions + + fourth_attempt_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + assert not server.user_session_manager.validate_remote_session_uuid(fourth_attempt_session_id) + + assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids]) + + +def test_single_remote_logout_others_persist(client_server_network): + client, server, network = client_server_network + + server.user_manager.add_user(username="jane.doe", password="12345") + server.user_manager.add_user(username="john.doe", password="12345") + + admin_session_id = server.user_session_manager.remote_login( + username="admin", password="admin", remote_ip_address="192.168.1.10" + ) + + jane_session_id = server.user_session_manager.remote_login( + username="jane.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + john_session_id = server.user_session_manager.remote_login( + username="john.doe", password="12345", remote_ip_address="192.168.1.10" + ) + + server.user_session_manager.remote_logout(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(john_session_id) + + server.user_session_manager.remote_logout(jane_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert server.user_session_manager.validate_remote_session_uuid(john_session_id) + + server.user_session_manager.remote_logout(john_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id) + + assert not server.user_session_manager.validate_remote_session_uuid(john_session_id) From 0ac1c6702c7369163562fa6015cf22f22f8e0412 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 26 Jul 2024 16:56:03 +0100 Subject: [PATCH 34/95] #2713 - eod commit. Initial RequestManager Test implemented, along with an initial setup of the additional Request Manager methods. --- CHANGELOG.md | 2 +- .../system/services/terminal/terminal.py | 97 +++++++++++--- .../_system/_services/test_terminal.py | 125 +++++++++++++++++- 3 files changed, 205 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24ff83ed..b27244bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,7 +64,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added ability to log each agent's action choices in each step to a JSON file. - Removal of Link bandwidth hardcoding. This can now be configured via the network configuration yaml. Will default to 100 if not present. - Added NMAP application to all host and layer-3 network nodes. -- Added Terminal Class for HostNode components +- Added Terminal Class for HostNode components. ### Bug Fixes diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index f01b44a2..559e234c 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -2,9 +2,10 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from uuid import uuid4 +from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse @@ -35,17 +36,12 @@ class TerminalClientConnection(BaseModel): is_active: bool = True """Flag to state whether the connection is still active or not.""" - _dest_ip_address: IPv4Address + dest_ip_address: IPv4Address = None """Destination IP address of connection""" _connection_uuid: str = None """Connection UUID""" - @property - def dest_ip_address(self) -> Optional[IPv4Address]: - """Destination IP Address.""" - return self._dest_ip_address - @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" @@ -95,6 +91,21 @@ class Terminal(Service): """Apply Terminal Request.""" return super().apply_request(request, context) + def show(self, markdown: bool = False): + """ + Display the remote connections to this terminal instance in tabular format. + + :param markdown: Whether to display the table in Markdown format or not. Default is `False`. + """ + table = PrettyTable(["Connection ID", "IP_Address", "Active"]) + if markdown: + table.set_style(MARKDOWN) + table.align = "l" + table.title = f"{self.sys_log.hostname} {self.name} Remote Connections" + for connection_id, connection in self.remote_connection.items(): + table.add_row([connection_id, connection.dest_ip_address, connection.is_active]) + print(table.get_string(sortby="Connection ID")) + def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" _login_valid = Terminal._LoginValidator(terminal=self) @@ -106,12 +117,52 @@ class Terminal(Service): func=lambda request, context: RequestResponse.from_bool(self.send()), validator=_login_valid ), ) - return rm - def _validate_login(self) -> bool: - """Validate login credentials are valid.""" - # return self.parent.UserSessionManager.validate_remote_session_uuid(self.connection_uuid) - return True + def _login(request: List[Any], context: Any) -> RequestResponse: + login = self._process_local_login(username=request[0], password=request[1]) + if login == True: + return RequestResponse(status="success", data={}) + else: + return RequestResponse(status="failure", data={}) + + def _remote_login(request: List[Any], context: Any) -> RequestResponse: + self._process_remote_login(username=request[0], password=request[1], ip_address=request[2]) + if self.is_connected: + return RequestResponse(status="success", data={}) + else: + return RequestResponse(status="failure", data={}) + + def _execute(request: List[Any], context: Any) -> RequestResponse: + """Execute an instruction.""" + command: str = request[0] + self.execute(command) + return RequestResponse(status="success", data={}) + + def _logoff() -> RequestResponse: + """Logoff from connection.""" + self.parent.UserSessionManager.logoff(self.connection_uuid) + self.disconnect(self.connection_uuid) + + return RequestResponse(status="success") + + rm.add_request( + "Login", + request_type=RequestType(func=_login), + ) + + rm.add_request( + "Remote Login", + request_type=RequestType(func=_remote_login), + ) + + rm.add_request( + "Execute", + request_type=RequestType(func=_execute), + ) + + rm.add_request("Logoff", request_type=RequestType(func=_logoff)) + + return rm class _LoginValidator(RequestPermissionValidator): """ @@ -155,7 +206,8 @@ class Terminal(Service): def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) - self.connection_uuid = str(uuid4()) + self.connection_uuid = str(uuid4()) # TODO: Remove following merging of UserSessionManager. + self.is_connected = True if self.connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") return True @@ -187,7 +239,7 @@ class Terminal(Service): self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) connection_uuid = str(uuid4()) - + self.is_connected = True if connection_uuid: # Send uuid to remote self.sys_log.info( @@ -203,14 +255,14 @@ class Terminal(Service): sender_ip_address=self.parent.network_interface[1].ip_address, target_ip_address=payload.sender_ip_address, ) - self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) self.remote_connection[connection_uuid] = TerminalClientConnection( parent_node=self.software_manager.node, - _dest_ip_address=payload.sender_ip_address, + dest_ip_address=payload.sender_ip_address, connection_uuid=connection_uuid, ) + self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) return True else: # UserSessionManager has returned None @@ -246,6 +298,9 @@ class Terminal(Service): self.is_connected = True return True + elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: + return self.execute(command=payload.payload) + else: self.sys_log.warning("Encounter unexpected message type, rejecting connection") return False @@ -254,6 +309,14 @@ class Terminal(Service): # %% Outbound + def execute(self, command: List[Any]) -> bool: + """Execute a passed ssh command via the request manager.""" + if command[0] == "install": + self.parent.software_manager.software.install(command[1]) + + return True + # TODO: Expand as necessary + def _disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from the remote.""" if not self.is_connected: @@ -266,7 +329,7 @@ class Terminal(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": self.remote_connection._connection_uuid}, + payload={"type": "disconnect", "connection_id": self.connection_uuid}, dest_ip_address=dest_ip_address, dest_port=self.port, ) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 65346b45..17af5699 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -3,12 +3,20 @@ from typing import Tuple import pytest +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import Terminal +from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState @@ -117,7 +125,7 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_a.login(ip_address="192.168.0.11") # login to computer_b + terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") # login to computer_b assert terminal_a.is_connected is True @@ -127,6 +135,121 @@ def test_terminal_ignores_when_off(basic_network): payload="Test_Payload", transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address="192.168.0.11", ) assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + + +def test_terminal_acknowledges_acl_rules(basic_network): + """Test that Terminal messages""" + + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") + + router = Router(hostname="router", num_ports=3, start_up_duration=0) + router.power_on() + router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") + router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") + + router.acl.add_rule(action=ACLAction.DENY, src_port=Port.SSH, dst_port=Port.SSH, position=22) + + +def test_network_simulation(basic_network): + # 0: Pull out the network + network = basic_network + + # 1: Set up network hardware + # 1.1: Configure the router + router = Router(hostname="router", num_ports=3, start_up_duration=0) + router.power_on() + router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") + router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") + + # 1.2: Create and connect switches + switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1.power_on() + network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) + router.enable_port(1) + switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2.power_on() + network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) + router.enable_port(2) + + # 1.3: Create and connect computer + client_1 = Computer( + hostname="client_1", + ip_address="10.0.1.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.1.1", + start_up_duration=0, + ) + client_1.power_on() + network.connect( + endpoint_a=client_1.network_interface[1], + endpoint_b=switch_1.network_interface[1], + ) + + # 1.4: Create and connect servers + server_1 = Server( + hostname="server_1", + ip_address="10.0.2.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + start_up_duration=0, + ) + server_1.power_on() + network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + + server_2 = Server( + hostname="server_2", + ip_address="10.0.2.3", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + start_up_duration=0, + ) + server_2.power_on() + network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) + + # 2: Configure base ACL + router.acl.add_rule(action=ACLAction.DENY, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + + # 3: Install server software + server_1.software_manager.install(DNSServer) + dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa + dns_service.dns_register("www.example.com", server_2.network_interface[1].ip_address) + server_2.software_manager.install(WebServer) + + # 3.1: Ensure that the dns clients are configured correctly + client_1.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address + server_2.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address + + terminal_1: Terminal = client_1.software_manager.software.get("Terminal") + + assert terminal_1.login(username="admin", password="Admin123!", ip_address="192.168.0.11") is False + + +def test_terminal_receives_requests(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + network: Network = game.simulation.network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert terminal_a.is_connected is False + + action = ("TERMINAL_LOGIN", {"username": "admin", "password": "Admin123!"}) # TODO: Add Action to ActionManager ? + + agent.store_action(action) + game.step() + + assert terminal_a.is_connected is True From 2e35549c956ba33b32111f2714a4954b2ebfd532 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 29 Jul 2024 09:29:20 +0100 Subject: [PATCH 35/95] #2735 - added docstrings to the User, UserManager, and UserSessionManager classes --- .../simulator/network/hardware/base.py | 230 ++++++++++++++++-- 1 file changed, 213 insertions(+), 17 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 3ffc7b35..e33c6014 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,7 +6,7 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union +from typing import Any, Dict, List, Optional, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field, validate_call @@ -799,16 +799,23 @@ class User(SimComponent): """ Represents a user in the PrimAITE system. - :param username: The username of the user - :param password: The password of the user - :param disabled: Boolean flag indicating whether the user is disabled - :param is_admin: Boolean flag indicating whether the user has admin privileges + :ivar username: The username of the user + :ivar password: The password of the user + :ivar disabled: Boolean flag indicating whether the user is disabled + :ivar is_admin: Boolean flag indicating whether the user has admin privileges """ username: str + """The username of the user""" + password: str + """The password of the user""" + disabled: bool = False + """Boolean flag indicating whether the user is disabled""" + is_admin: bool = False + """Boolean flag indicating whether the user has admin privileges""" def describe_state(self) -> Dict: """ @@ -971,47 +978,131 @@ class UserManager(Service): class UserSession(SimComponent): + """ + Represents a user session on the Node. + + This class manages the state of a user session, including the user, session start, last active step, + and end step. It also indicates whether the session is local. + + :ivar user: The user associated with this session. + :ivar start_step: The timestep when the session was started. + :ivar last_active_step: The last timestep when the session was active. + :ivar end_step: The timestep when the session ended, if applicable. + :ivar local: Indicates if the session is local. Defaults to True. + """ + user: User + """The user associated with this session.""" + start_step: int + """The timestep when the session was started.""" + last_active_step: int + """The last timestep when the session was active.""" + end_step: Optional[int] = None + """The timestep when the session ended, if applicable.""" + local: bool = True + """Indicates if the session is local. Defaults to True.""" @classmethod def create(cls, user: User, timestep: int) -> UserSession: + """ + Creates a new instance of UserSession. + + This class method initialises a user session with the given user and timestep. + + :param user: The user associated with this session. + :param timestep: The timestep when the session is created. + :return: An instance of UserSession. + """ return UserSession(user=user, start_step=timestep, last_active_step=timestep) def describe_state(self) -> Dict: + """ + Describes the current state of the user session. + + :return: A dictionary representing the state of the user session. + """ return self.model_dump() class RemoteUserSession(UserSession): + """ + Represents a remote user session on the Node. + + This class extends the UserSession class to include additional attributes and methods specific to remote sessions. + + :ivar remote_ip_address: The IP address of the remote user. + :ivar local: Indicates that this is not a local session. Always set to False. + """ + remote_ip_address: IPV4Address + """The IP address of the remote user.""" + local: bool = False + """Indicates that this is not a local session. Always set to False.""" @classmethod def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa + """ + Creates a new instance of RemoteUserSession. + + This class method initialises a remote user session with the given user, timestep, and remote IP address. + + :param user: The user associated with this session. + :param timestep: The timestep when the session is created. + :param remote_ip_address: The IP address of the remote user. + :return: An instance of RemoteUserSession. + """ return RemoteUserSession( user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address ) def describe_state(self) -> Dict: + """ + Describes the current state of the remote user session. + + This method extends the base describe_state method to include the remote IP address. + + :return: A dictionary representing the state of the remote user session. + """ state = super().describe_state() state["remote_ip_address"] = str(self.remote_ip_address) return state class UserSessionManager(Service): + """ + Manages user sessions on a Node, including local and remote sessions. + + This class handles authentication, session management, and session timeouts for users interacting with the Node. + """ + node: Node + """The node associated with this UserSessionManager.""" + local_session: Optional[UserSession] = None + """The current local user session, if any.""" + remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict) + """A dictionary of active remote user sessions.""" + historic_sessions: List[UserSession] = Field(default_factory=list) + """A list of historic user sessions.""" local_session_timeout_steps: int = 30 + """The number of steps before a local session times out due to inactivity.""" + remote_session_timeout_steps: int = 5 + """The number of steps before a remote session times out due to inactivity.""" + max_remote_sessions: int = 3 + """The maximum number of concurrent remote sessions allowed.""" current_timestep: int = 0 + """The current timestep in the simulation.""" def __init__(self, **kwargs): """ @@ -1027,7 +1118,13 @@ class UserSessionManager(Service): self.start() def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False): - """Prints a table of the user sessions on the Node.""" + """ + Displays a table of the user sessions on the Node. + + :param markdown: Whether to display the table in markdown format. + :param include_session_id: Whether to include session IDs in the table. + :param include_historic: Whether to include historic sessions in the table. + """ headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"] if not include_session_id: @@ -1041,6 +1138,14 @@ class UserSessionManager(Service): table.title = f"{self.node.hostname} User Sessions" def _add_session_to_table(user_session: UserSession): + """ + Adds a user session to the table for display. + + This helper function determines whether the session is local or remote and formats the session data + accordingly. It then adds the session data to the table. + + :param user_session: The user session to add to the table. + """ session_type = "local" remote_ip = "" if isinstance(user_session, RemoteUserSession): @@ -1072,12 +1177,22 @@ class UserSessionManager(Service): print(table.get_string(sortby="Step Last Active", reversesort=True)) def describe_state(self) -> Dict: + """ + Describes the current state of the UserSessionManager. + + :return: A dictionary representing the current state. + """ state = super().describe_state() state["active_remote_logins"] = len(self.remote_sessions) return state @property def _user_manager(self) -> UserManager: + """ + Returns the UserManager instance. + + :return: The UserManager instance. + """ return self.software_manager.software["UserManager"] # noqa def pre_timestep(self, timestep: int) -> None: @@ -1088,6 +1203,11 @@ class UserSessionManager(Service): self._timeout_session(self.local_session) def _timeout_session(self, session: UserSession) -> None: + """ + Handles session timeout logic. + + :param session: The session to be timed out. + """ session.end_step = self.current_timestep session_identity = session.user.username if session.local: @@ -1102,14 +1222,34 @@ class UserSessionManager(Service): @property def remote_session_limit_reached(self) -> bool: + """ + Checks if the maximum number of remote sessions has been reached. + + :return: True if the limit is reached, otherwise False. + """ return len(self.remote_sessions) >= self.max_remote_sessions def validate_remote_session_uuid(self, remote_session_id: str) -> bool: + """ + Validates if a given remote session ID exists. + + :param remote_session_id: The remote session ID to validate. + :return: True if the session ID exists, otherwise False. + """ return remote_session_id in self.remote_sessions def _login( - self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None + self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None ) -> Optional[str]: + """ + Logs a user in either locally or remotely. + + :param username: The username of the account. + :param password: The password of the account. + :param local: Whether the login is local or remote. + :param remote_ip_address: The remote IP address for remote login. + :return: The session ID if login is successful, otherwise None. + """ if not self._can_perform_action(): return None @@ -1145,13 +1285,35 @@ class UserSessionManager(Service): return session_id def local_login(self, username: str, password: str) -> Optional[str]: + """ + Logs a user in locally. + + :param username: The username of the account. + :param password: The password of the account. + :return: The session ID if login is successful, otherwise None. + """ return self._login(username=username, password=password, local=True) @validate_call() def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]: + """ + Logs a user in remotely. + + :param username: The username of the account. + :param password: The password of the account. + :param remote_ip_address: The remote IP address for the remote login. + :return: The session ID if login is successful, otherwise None. + """ return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address) - def _logout(self, local: bool = True, remote_session_id: Optional[str] = None): + def _logout(self, local: bool = True, remote_session_id: Optional[str] = None) -> bool: + """ + Logs a user out either locally or remotely. + + :param local: Whether the logout is local or remote. + :param remote_session_id: The remote session ID for remote logout. + :return: True if logout successful, otherwise False. + """ if not self._can_perform_action(): return False session = None @@ -1165,16 +1327,33 @@ class UserSessionManager(Service): if session: self.historic_sessions.append(session) self.sys_log.info(f"{self.name}: User {session.user.username} logged out") - return + return True + return False - def local_logout(self): - self._logout(local=True) + def local_logout(self) -> bool: + """ + Logs out the current local user. - def remote_logout(self, remote_session_id: str): - self._logout(local=False, remote_session_id=remote_session_id) + :return: True if logout successful, otherwise False. + """ + return self._logout(local=True) + + def remote_logout(self, remote_session_id: str) -> bool: + """ + Logs out a remote user by session ID. + + :param remote_session_id: The remote session ID. + :return: True if logout successful, otherwise False. + """ + return self._logout(local=False, remote_session_id=remote_session_id) @property - def local_user_logged_in(self): + def local_user_logged_in(self) -> bool: + """ + Checks if a local user is currently logged in. + + :return: True if a local user is logged in, otherwise False. + """ return self.local_session is not None @@ -1249,7 +1428,7 @@ class Node(SimComponent): """ Initialize the Node with various components and managers. - This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not + This method initialises the ARP cache, ICMP handler, session manager, and software manager if they are not provided. """ if not kwargs.get("sys_log"): @@ -1278,17 +1457,34 @@ class Node(SimComponent): @property def user_manager(self) -> UserManager: + """The Nodes User Manager.""" return self.software_manager.software["UserManager"] # noqa @property def user_session_manager(self) -> UserSessionManager: + """The Nodes User Session Manager.""" return self.software_manager.software["UserSessionManager"] # noqa def local_login(self, username: str, password: str) -> Optional[str]: + """ + Attempt to log in to the node uas a local user. + + This method attempts to authenticate a local user with the given username and password. If successful, it + returns a session token. If authentication fails, it returns None. + + :param username: The username of the account attempting to log in. + :param password: The password of the account attempting to log in. + :return: A session token if the login is successful, otherwise None. + """ return self.user_session_manager.local_login(username, password) - def logout(self): - return self.user_session_manager.logout() + def local_logout(self) -> None: + """ + Log out the current local user from the node. + + This method ends the current local user's session and invalidates the session token. + """ + return self.user_session_manager.local_logout() def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool: """ From 265632669ee9f947c0fca6916000899181e3b529 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 29 Jul 2024 10:29:12 +0100 Subject: [PATCH 36/95] #2778 - added request managers for USerManager and UserSessionManager classes --- .../simulator/network/hardware/base.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index e33c6014..0a561707 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -850,8 +850,29 @@ class UserManager(Service): kwargs["port"] = Port.NONE kwargs["protocol"] = IPProtocol.NONE super().__init__(**kwargs) + self._request_manager = None + self.start() + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ + rm = super()._init_request_manager() + + # todo add doc about requeest schemas + rm.add_request( + "change_password", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.change_user_password(username=request[0], current_password=request[1], new_password=request[2]) + ) + ), + ) + return rm + def describe_state(self) -> Dict: """ Returns the state of the UserManager along with the number of users and admins. @@ -1117,6 +1138,34 @@ class UserSessionManager(Service): super().__init__(**kwargs) self.start() + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ + rm = super()._init_request_manager() + + # todo add doc about requeest schemas + rm.add_request( + "remote_login", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.remote_login(username=request[0], password=request[1], remote_ip_address=request[2]) + ) + ), + ) + + rm.add_request( + "remote_logout", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.remote_logout(remote_session_id=request[0]) + ) + ), + ) + return rm + def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False): """ Displays a table of the user sessions on the Node. @@ -1686,6 +1735,10 @@ class Node(SimComponent): self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application)) self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application)) + rm.add_request("accounts", RequestType(func=self.user_manager._request_manager)) # noqa + + rm.add_request("sessions", RequestType(func=self.user_session_manager._request_manager)) # noqa + return rm def describe_state(self) -> Dict: @@ -1868,7 +1921,6 @@ class Node(SimComponent): def pre_timestep(self, timestep: int) -> None: """Apply pre-timestep logic.""" super().pre_timestep(timestep) - self._ for network_interface in self.network_interfaces.values(): network_interface.pre_timestep(timestep=timestep) From cf7341a4fda5994c4000ae5730d11921b5658ed0 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 29 Jul 2024 10:50:32 +0100 Subject: [PATCH 37/95] #2713 - Minor changes before merging into main Terminal branch --- .../system/services/terminal/terminal.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 559e234c..cadc8853 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -3,7 +3,6 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Any, Dict, List, Optional -from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel @@ -157,10 +156,10 @@ class Terminal(Service): rm.add_request( "Execute", - request_type=RequestType(func=_execute), + request_type=RequestType(func=_execute, validator=_login_valid), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff)) + rm.add_request("Logoff", request_type=RequestType(func=_logoff, validator=_login_valid)) return rm @@ -205,8 +204,7 @@ class Terminal(Service): def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" - # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) - self.connection_uuid = str(uuid4()) # TODO: Remove following merging of UserSessionManager. + self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) self.is_connected = True if self.connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") @@ -233,12 +231,15 @@ class Terminal(Service): return self.send(payload=payload, dest_ip_address=ip_address) def _process_remote_login(self, payload: SSHPacket) -> bool: - """Processes a remote terminal requesting to login to this terminal.""" + """Processes a remote terminal requesting to login to this terminal. + + :param payload: The SSH Payload Packet. + :return: True if successful, else False. + """ username: str = payload.user_account.username password: str = payload.user_account.password self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") - # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) - connection_uuid = str(uuid4()) + connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) self.is_connected = True if connection_uuid: # Send uuid to remote @@ -270,7 +271,11 @@ class Terminal(Service): return False def receive(self, payload: SSHPacket, **kwargs) -> bool: - """Receive Payload and process for a response.""" + """Receive Payload and process for a response. + + :param payload: The message contents received. + :return: True if successfull, else False. + """ self.sys_log.debug(f"Received payload: {payload}") if not isinstance(payload, SSHPacket): @@ -286,11 +291,9 @@ class Terminal(Service): dest_ip_address = kwargs["dest_ip_address"] self.disconnect(dest_ip_address=dest_ip_address) self.sys_log.debug(f"Disconnecting {connection_id}") - # We need to close on the other machine as well elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - """Login Request Received.""" - self._process_remote_login(payload=payload) + return self._process_remote_login(payload=payload) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") @@ -311,11 +314,12 @@ class Terminal(Service): def execute(self, command: List[Any]) -> bool: """Execute a passed ssh command via the request manager.""" + # TODO: Expand as necessary, as new functionalilty is needed. if command[0] == "install": self.parent.software_manager.software.install(command[1]) - - return True - # TODO: Expand as necessary + return True + else: + return False def _disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from the remote.""" From f78cb24150ec8bac8f0be0970874df7e1836e850 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 29 Jul 2024 14:20:29 +0100 Subject: [PATCH 38/95] #2706 - Removed some un-necessary comments and changes to network used in terminal ACL unit test --- .../simulator/system/services/terminal/terminal.py | 6 ------ .../_simulator/_system/_services/test_terminal.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index cadc8853..3caf57be 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -72,8 +72,6 @@ class Terminal(Service): kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - # %% Util - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -182,8 +180,6 @@ class Terminal(Service): """Message that is reported when a request is rejected by this validator.""" return "Cannot perform request on terminal as not logged in." - # %% Inbound - def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. @@ -310,8 +306,6 @@ class Terminal(Service): return True - # %% Outbound - def execute(self, command: List[Any]) -> bool: """Execute a passed ssh command via the request manager.""" # TODO: Expand as necessary, as new functionalilty is needed. diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 17af5699..e1241bbe 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -17,7 +17,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.services.web_server.web_server import WebServer -from primaite.simulator.system.software import SoftwareHealthState @pytest.fixture(scope="function") @@ -194,6 +193,14 @@ def test_network_simulation(basic_network): endpoint_b=switch_1.network_interface[1], ) + client_2 = Computer( + hostname="client_2", + ip_address="10.0.2.2", + subnet_mask="255.255.255.0", + ) + client_2.power_on() + network.connect(endpoint_a=client_2.network_interface[1], endpoint_b=switch_2.network_interface[1]) + # 1.4: Create and connect servers server_1 = Server( hostname="server_1", @@ -233,7 +240,7 @@ def test_network_simulation(basic_network): terminal_1: Terminal = client_1.software_manager.software.get("Terminal") - assert terminal_1.login(username="admin", password="Admin123!", ip_address="192.168.0.11") is False + assert terminal_1.login(username="admin", password="Admin123!", ip_address="10.0.2.2") is False def test_terminal_receives_requests(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): From 3d13669671403dbda1a60c07a65aab9f1e755328 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 29 Jul 2024 15:12:24 +0100 Subject: [PATCH 39/95] #2735: fixes to broken items --- src/primaite/simulator/network/hardware/base.py | 17 ++++++++++------- .../network/hardware/nodes/network/switch.py | 3 +++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0a561707..08f14b7e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -850,7 +850,6 @@ class UserManager(Service): kwargs["port"] = Port.NONE kwargs["protocol"] = IPProtocol.NONE super().__init__(**kwargs) - self._request_manager = None self.start() @@ -1499,20 +1498,28 @@ class Node(SimComponent): super().__init__(**kwargs) self.session_manager.node = self self.session_manager.software_manager = self.software_manager + self.software_manager.install(UserSessionManager, node=self) + self._request_manager.add_request( + "sessions", RequestType(func=self.user_session_manager._request_manager) + ) # noqa + self.software_manager.install(UserManager) + self._request_manager.add_request("accounts", RequestType(func=self.user_manager._request_manager)) # noqa + self.user_manager.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True) + self._install_system_software() @property def user_manager(self) -> UserManager: """The Nodes User Manager.""" - return self.software_manager.software["UserManager"] # noqa + return self.software_manager.software.get("UserManager") # noqa @property def user_session_manager(self) -> UserSessionManager: """The Nodes User Session Manager.""" - return self.software_manager.software["UserSessionManager"] # noqa + return self.software_manager.software.get("UserSessionManager") # noqa def local_login(self, username: str, password: str) -> Optional[str]: """ @@ -1735,10 +1742,6 @@ class Node(SimComponent): self._application_manager.add_request(name="install", request_type=RequestType(func=_install_application)) self._application_manager.add_request(name="uninstall", request_type=RequestType(func=_uninstall_application)) - rm.add_request("accounts", RequestType(func=self.user_manager._request_manager)) # noqa - - rm.add_request("sessions", RequestType(func=self.user_session_manager._request_manager)) # noqa - return rm def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 1a7da2e7..4324ac94 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -108,6 +108,9 @@ class Switch(NetworkNode): for i in range(1, self.num_ports + 1): self.connect_nic(SwitchPort()) + def _install_system_software(self): + pass + def show(self, markdown: bool = False): """ Prints a table of the SwitchPorts on the Switch. From 0fad61eaea2d1d39c94fe5241125292c5686fc71 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 29 Jul 2024 15:15:15 +0100 Subject: [PATCH 40/95] #2735: pipeline build fail if test fails --- .azure/azure-ci-build-pipeline.yaml | 4 +--- run_test_and_coverage.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) create mode 100644 run_test_and_coverage.py diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 01111290..2375a391 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -102,9 +102,7 @@ stages: version: '2.1.x' - script: | - coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml --cov-fail-under=80 - coverage xml -o coverage.xml -i - coverage html -d htmlcov -i + python run_test_and_coverage.py displayName: 'Run tests and code coverage' # Run the notebooks diff --git a/run_test_and_coverage.py b/run_test_and_coverage.py new file mode 100644 index 00000000..3bd9072d --- /dev/null +++ b/run_test_and_coverage.py @@ -0,0 +1,22 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import subprocess +import sys +from typing import Any + + +def run_command(command: Any): + """Runs a command and returns the exit code.""" + result = subprocess.run(command, shell=True) + if result.returncode != 0: + sys.exit(result.returncode) + + +# Run pytest with coverage +run_command( + "coverage run -m --source=primaite pytest -v -o junit_family=xunit2 " + "--junitxml=junit/test-results.xml --cov-fail-under=80" +) + +# Generate coverage reports if tests passed +run_command("coverage xml -o coverage.xml -i") +run_command("coverage html -d htmlcov -i") From e492f19a437b7aa119b524ac556ee91b99e1d900 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 29 Jul 2024 17:10:13 +0100 Subject: [PATCH 41/95] #2706 - Small change to execute method following feedback --- .../simulator/system/services/terminal/terminal.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 3caf57be..ca0d7c1f 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -306,14 +306,9 @@ class Terminal(Service): return True - def execute(self, command: List[Any]) -> bool: + def execute(self, command: List[Any]) -> RequestResponse: """Execute a passed ssh command via the request manager.""" - # TODO: Expand as necessary, as new functionalilty is needed. - if command[0] == "install": - self.parent.software_manager.software.install(command[1]) - return True - else: - return False + return self.parent.apply_request(command) def _disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from the remote.""" From c984d695cca3b2ac53d8ce7eff3fcce34aa43b94 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 29 Jul 2024 23:03:26 +0100 Subject: [PATCH 42/95] #2735: use ray version 2.32 until 2.33 is fixed --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e919604..01be8d52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.20.0, < 3", + "ray[rllib] == 2.32.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", "sb3-contrib==2.1.0", From bb0ecb93a4b9070b66da36f51c44bd4eb5f49d74 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 09:57:47 +0100 Subject: [PATCH 43/95] #2706 - Correcting whitespace change in database_service.py and actioning some review comments --- src/primaite/simulator/network/protocols/ssh.py | 3 --- .../simulator/system/services/database/database_service.py | 2 +- src/primaite/simulator/system/services/terminal/terminal.py | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 5eb181a6..8671a1c8 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -7,9 +7,6 @@ from typing import Optional from primaite.interface.request import RequestResponse from primaite.simulator.network.protocols.packet import DataPacket -# TODO: Elaborate / Confirm / Validate - See 2709. -# Placeholder implementation for Terminal Class implementation. - class SSHTransportMessage(IntEnum): """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index f061b3c7..22ae0ff3 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -21,7 +21,7 @@ class DatabaseService(Service): """ A class for simulating a generic SQL Server service. - This class inherits from the `Service` class and provides methods to simulate a SQL database. + This class inherits from the `Service` class and provides methods to simulate a SQL database. """ password: Optional[str] = None diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index ca0d7c1f..884d3f5b 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -117,7 +117,7 @@ class Terminal(Service): def _login(request: List[Any], context: Any) -> RequestResponse: login = self._process_local_login(username=request[0], password=request[1]) - if login == True: + if login: return RequestResponse(status="success", data={}) else: return RequestResponse(status="failure", data={}) @@ -140,7 +140,7 @@ class Terminal(Service): self.parent.UserSessionManager.logoff(self.connection_uuid) self.disconnect(self.connection_uuid) - return RequestResponse(status="success") + return RequestResponse(status="success", data={}) rm.add_request( "Login", From ab267982404482907ade2f40af6a120a2d3bab24 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 10:23:34 +0100 Subject: [PATCH 44/95] #2706 - New test to check that the terminal can receive and process commmands. --- .../_system/_services/test_terminal.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index e1241bbe..7dd7c2b1 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -65,7 +65,7 @@ def test_terminal_not_on_switch(): def test_terminal_send(basic_network): - """Check that Terminal can send""" + """Test that Terminal can send valid commands.""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") @@ -82,6 +82,28 @@ def test_terminal_send(basic_network): assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) +def test_terminal_receive(basic_network): + """Test that terminal can receive and process commands""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + folder_name = "Downloads" + + payload: SSHPacket = SSHPacket( + payload=["file_system", "create", "folder", folder_name], + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address=computer_b.network_interface[1].ip_address, + ) + + assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) + + # Assert that the Folder has been correctly created + assert computer_b.file_system.get_folder(folder_name) + + def test_terminal_fail_when_closed(basic_network): """Ensure Terminal won't attempt to send/receive when off""" network: Network = basic_network @@ -254,7 +276,7 @@ def test_terminal_receives_requests(game_and_agent_fixture: Tuple[PrimaiteGame, assert terminal_a.is_connected is False - action = ("TERMINAL_LOGIN", {"username": "admin", "password": "Admin123!"}) # TODO: Add Action to ActionManager ? + action = ("TERMINAL_LOGIN", {"username": "admin", "password": "Admin123!"}) agent.store_action(action) game.step() From 2b33a6edb4fe63214421f9da9959718f74e493f2 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 11:04:55 +0100 Subject: [PATCH 45/95] #2706 - New unit test to show that Terminal is able to send/handle install commands --- .../_system/_services/test_terminal.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 7dd7c2b1..aad32863 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -13,6 +13,7 @@ from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import Terminal @@ -104,6 +105,26 @@ def test_terminal_receive(basic_network): assert computer_b.file_system.get_folder(folder_name) +def test_terminal_install(basic_network): + """Test that Terminal can successfully process an INSTALL request""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + payload: SSHPacket = SSHPacket( + payload=["software_manager", "application", "install", "RansomwareScript"], + transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + sender_ip_address=computer_a.network_interface[1].ip_address, + target_ip_address=computer_b.network_interface[1].ip_address, + ) + + terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) + + assert computer_b.software_manager.software.get("RansomwareScript") + + def test_terminal_fail_when_closed(basic_network): """Ensure Terminal won't attempt to send/receive when off""" network: Network = basic_network From 2f50feb0a068171ec5afb7eb99391ad963c5b749 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 11:11:08 +0100 Subject: [PATCH 46/95] #2706 - Removing redundant unit test from --- .../_system/_services/test_terminal.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index aad32863..411f0ebe 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -184,23 +184,6 @@ def test_terminal_ignores_when_off(basic_network): assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") -def test_terminal_acknowledges_acl_rules(basic_network): - """Test that Terminal messages""" - - network: Network = basic_network - computer_a: Computer = network.get_node_by_hostname("node_a") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") - - terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") - - router = Router(hostname="router", num_ports=3, start_up_duration=0) - router.power_on() - router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") - router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") - - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.SSH, dst_port=Port.SSH, position=22) - - def test_network_simulation(basic_network): # 0: Pull out the network network = basic_network From 09084574a87f22b6bd2aacc0766c3aa2c9b5a341 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 12:15:37 +0100 Subject: [PATCH 47/95] #2706 - Inclusion of health_state_actual attribute to the Terminal class. Started fleshing out a walkthrough notebook showing how to use the new component. --- .../notebooks/Terminal-Processing.ipynb | 157 ++++++++++++++++++ .../system/services/terminal/terminal.py | 6 +- 2 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 src/primaite/notebooks/Terminal-Processing.ipynb diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb new file mode 100644 index 00000000..6a197b03 --- /dev/null +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Terminal Processing\n", + "\n", + "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n", + "\n", + "By default, the Terminal will come pre-installed on any simulation component which inherits from `HostNode`, and simulates the Secure Shell (SSH) protocol as the communication method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.simulator.system.services.terminal.terminal import Terminal\n", + "from primaite.simulator.network.container import Network\n", + "from primaite.simulator.network.hardware.nodes.host.computer import Computer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def basic_network() -> Network:\n", + " \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n", + " network = Network()\n", + " node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n", + " node_a.power_on()\n", + " node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n", + " node_b.power_on()\n", + " network.connect(node_a.network_interface[1], node_b.network_interface[1])\n", + " return network" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "demonstrate how we obtain the Terminal component" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "network: Network = basic_network()\n", + "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", + "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", + "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", + "\n", + "# The below can be un-commented when UserSessionManager is implemented. Will need to login before sending any SSH commands\n", + "# to remote.\n", + "# terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Terminal can be used to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n", + "\n", + "Once ran and the command sent, the `RansomwareScript` can be seen in the list of applications on the `node_b Software Manager`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage\n", + "from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n", + "\n", + "computer_b.software_manager.show()\n", + "\n", + "payload: SSHPacket = SSHPacket(\n", + " payload=[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n", + " transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n", + " connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n", + " sender_ip_address=computer_a.network_interface[1].ip_address,\n", + " target_ip_address=computer_b.network_interface[1].ip_address,\n", + ")\n", + "\n", + "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)\n", + "\n", + "computer_b.software_manager.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The below example shows how you can send a command via the terminal to create a folder on the target Node.\n", + "\n", + "Here, we send a command to `computer_b` to create a new folder titled \"Downloads\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.file_system.show()\n", + "\n", + "payload: SSHPacket = SSHPacket(\n", + " payload=[\"file_system\", \"create\", \"folder\", \"Downloads\"],\n", + " transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n", + " connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n", + " sender_ip_address=computer_a.network_interface[1].ip_address,\n", + " target_ip_address=computer_b.network_interface[1].ip_address,\n", + ")\n", + "\n", + "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)\n", + "\n", + "computer_b.file_system.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 884d3f5b..eae21804 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -20,6 +20,7 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +from primaite.simulator.system.software import SoftwareHealthState class TerminalClientConnection(BaseModel): @@ -62,7 +63,10 @@ class Terminal(Service): "Uuid for connection requests" operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING - """Initial Operating State""" + "Initial Operating State" + + health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD + "Service Health State" remote_connection: Dict[str, TerminalClientConnection] = {} From 5e3a16999952aab47983f99175937da94a577826 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 30 Jul 2024 12:48:11 +0100 Subject: [PATCH 48/95] #2735: add usermanager and usersessionmanager into describe_state --- src/primaite/simulator/network/hardware/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 08f14b7e..05e52e32 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1767,6 +1767,8 @@ class Node(SimComponent): "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, "revealed_to_red": self.revealed_to_red, + "user_manager": self.user_manager.describe_state(), + "user_session_manager": self.user_session_manager.describe_state(), } ) return state From 3698e6ff5fd20316979ec2c6cbe374ca7331850e Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 15:24:37 +0100 Subject: [PATCH 49/95] #2706 - Commented out references to UserSessionManager to remove the dependency. --- src/primaite/notebooks/Terminal-Processing.ipynb | 9 ++++++++- .../system/services/terminal/terminal.py | 16 +++++++++------- .../_system/_services/test_terminal.py | 15 +++++++++++++-- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 6a197b03..4cb962ca 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -15,7 +15,7 @@ "source": [ "This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n", "\n", - "By default, the Terminal will come pre-installed on any simulation component which inherits from `HostNode`, and simulates the Secure Shell (SSH) protocol as the communication method." + "By default, the Terminal will come pre-installed on any simulation component which inherits from `HostNode` (Computer, Server, Printer), and simulates the Secure Shell (SSH) protocol as the communication method." ] }, { @@ -131,6 +131,13 @@ "\n", "computer_b.file_system.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resultant call to `computer_b.file_system.show()` shows that the new folder has been created." + ] } ], "metadata": { diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index eae21804..50d30a34 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -3,6 +3,7 @@ from __future__ import annotations from ipaddress import IPv4Address from typing import Any, Dict, List, Optional +from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel @@ -88,10 +89,6 @@ class Terminal(Service): state = super().describe_state() return state - def apply_request(self, request: List[str | int | float | Dict], context: Dict | None = None) -> RequestResponse: - """Apply Terminal Request.""" - return super().apply_request(request, context) - def show(self, markdown: bool = False): """ Display the remote connections to this terminal instance in tabular format. @@ -141,7 +138,8 @@ class Terminal(Service): def _logoff() -> RequestResponse: """Logoff from connection.""" - self.parent.UserSessionManager.logoff(self.connection_uuid) + # TODO: Uncomment this when UserSessionManager merged. + # self.parent.UserSessionManager.logoff(self.connection_uuid) self.disconnect(self.connection_uuid) return RequestResponse(status="success", data={}) @@ -204,7 +202,9 @@ class Terminal(Service): def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" - self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + # TODO: Un-comment this when UserSessionManager is merged. + # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + self.connection_uuid = str(uuid4()) self.is_connected = True if self.connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") @@ -239,7 +239,9 @@ class Terminal(Service): username: str = payload.user_account.username password: str = payload.user_account.password self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") - connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + # TODO: Un-comment this when UserSessionManager is merged. + # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) + connection_uuid = str(uuid4()) self.is_connected = True if connection_uuid: # Send uuid to remote diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 411f0ebe..8ec20394 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -45,6 +45,17 @@ def basic_network() -> Network: return network +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + def test_terminal_creation(terminal_on_computer): terminal, computer = terminal_on_computer terminal.describe_state() @@ -273,10 +284,10 @@ def test_terminal_receives_requests(game_and_agent_fixture: Tuple[PrimaiteGame, game, agent = game_and_agent_fixture network: Network = game.simulation.network - computer_a: Computer = network.get_node_by_hostname("node_a") + computer_a: Computer = network.get_node_by_hostname("client_1") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") - computer_b: Computer = network.get_node_by_hostname("node_b") + computer_b: Computer = network.get_node_by_hostname("client_2") assert terminal_a.is_connected is False From 0ed61ec79ba3f1a5f604949caecca26dfc7f80df Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 15:54:08 +0100 Subject: [PATCH 50/95] #2706 - Updates to terminal and host_node documentation, removal of redundant terminal unit test --- .../network/nodes/host_node.rst | 2 ++ .../system/services/terminal.rst | 1 + .../_system/_services/test_terminal.py | 19 ------------------- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/docs/source/simulation_components/network/nodes/host_node.rst b/docs/source/simulation_components/network/nodes/host_node.rst index 301cd783..b8aae098 100644 --- a/docs/source/simulation_components/network/nodes/host_node.rst +++ b/docs/source/simulation_components/network/nodes/host_node.rst @@ -49,3 +49,5 @@ fundamental network operations: 5. **NTP (Network Time Protocol) Client:** Synchronises the host's clock with network time servers. 6. **Web Browser:** A simulated application that allows the host to request and display web content. + +7. **Terminal:** A simulated service that allows the host to connect to remote hosts and execute commands. diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 4d1285d1..4b02a6db 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -41,6 +41,7 @@ to provide User Credential authentication when receiving/processing commands. Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating the passing of requests to both. +A more detailed example of how to use the Terminal class can be found in the Terminal-Processing jupyter notebook. Python """""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 8ec20394..d4592228 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -278,22 +278,3 @@ def test_network_simulation(basic_network): terminal_1: Terminal = client_1.software_manager.software.get("Terminal") assert terminal_1.login(username="admin", password="Admin123!", ip_address="10.0.2.2") is False - - -def test_terminal_receives_requests(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): - game, agent = game_and_agent_fixture - - network: Network = game.simulation.network - computer_a: Computer = network.get_node_by_hostname("client_1") - terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") - - computer_b: Computer = network.get_node_by_hostname("client_2") - - assert terminal_a.is_connected is False - - action = ("TERMINAL_LOGIN", {"username": "admin", "password": "Admin123!"}) - - agent.store_action(action) - game.step() - - assert terminal_a.is_connected is True From 06ac127f6bc90acbf40c7b4fb3b19248f9f95e65 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 30 Jul 2024 16:58:40 +0100 Subject: [PATCH 51/95] #2706 - Updates to Terminal Processing notebook to highlight utility function and improve formatting --- .../notebooks/Terminal-Processing.ipynb | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 4cb962ca..c9321b01 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -63,19 +63,33 @@ "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", + "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")\n", "\n", - "# The below can be un-commented when UserSessionManager is implemented. Will need to login before sending any SSH commands\n", - "# to remote.\n", - "# terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" + "# Login to the remote (node_b) from local (node_a)\n", + "terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The Terminal can be used to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n", - "\n", - "Once ran and the command sent, the `RansomwareScript` can be seen in the list of applications on the `node_b Software Manager`. " + "You can view all remote connections to a terminal through use of the `show()` method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminal_b.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Terminal can be used to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n" ] }, { @@ -97,8 +111,23 @@ " target_ip_address=computer_b.network_interface[1].ip_address,\n", ")\n", "\n", - "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)\n", - "\n", + "# Send commmand to install RansomwareScript\n", + "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `RansomwareScript` can then be seen in the list of applications on the `node_b Software Manager`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "computer_b.software_manager.show()" ] }, From 9bf8d0f8cbce18542622bf772fd9abb1edf50bc6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 Jul 2024 13:20:15 +0100 Subject: [PATCH 52/95] #2676 Put NMNE back into network module --- src/primaite/game/game.py | 4 +- src/primaite/session/io.py | 45 ------------------- .../simulator/network/hardware/base.py | 2 +- src/primaite/simulator/network/nmne.py | 25 +++++++++++ .../observations/test_nic_observations.py | 4 +- .../network/test_capture_nmne.py | 8 ++-- 6 files changed, 34 insertions(+), 54 deletions(-) create mode 100644 src/primaite/simulator/network/nmne.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cd0180db..2e7ee735 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -16,7 +16,6 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort -from primaite.session.io import store_nmne_config from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState @@ -27,6 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application @@ -265,7 +265,7 @@ class PrimaiteGame: nodes_cfg = network_config.get("nodes", []) links_cfg = network_config.get("links", []) # Set the NMNE capture config - NetworkInterface.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) + NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) for node_cfg in nodes_cfg: n_type = node_cfg["type"] diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index c634e835..78d7cb3c 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -131,48 +131,3 @@ class PrimaiteIO: new = cls(settings=cls.Settings(**config)) return new - - -class NMNEConfig(BaseModel): - """Store all the information to perform NMNE operations.""" - - capture_nmne: bool = False - """Indicates whether Malicious Network Events (MNEs) should be captured.""" - nmne_capture_keywords: List[str] = [] - """List of keywords to identify malicious network events.""" - capture_by_direction: bool = True - """Captures should be organized by traffic direction (inbound/outbound).""" - capture_by_ip_address: bool = False - """Captures should be organized by source or destination IP address.""" - capture_by_protocol: bool = False - """Captures should be organized by network protocol (e.g., TCP, UDP).""" - capture_by_port: bool = False - """Captures should be organized by source or destination port.""" - capture_by_keyword: bool = False - """Captures should be filtered and categorised based on specific keywords.""" - - -def store_nmne_config(nmne_config: Dict) -> NMNEConfig: - """ - Store configuration for capturing Malicious Network Events (MNEs). - - This function updates settings related to NMNE capture, stored in NMNEConfig including whether - to capture NMNEs and the keywords to use for identifying NMNEs. - - The function ensures that the settings are updated only if they are provided in the - `nmne_config` dictionary, and maintains type integrity by relying on pydantic validators. - - :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys - include: - "capture_nmne" (bool) to indicate whether NMNEs should be captured; - "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. - :rvar class with data read from config file. - """ - nmne_capture_keywords: List[str] = [] - # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect - capture_nmne = nmne_config.get("capture_nmne", False) - - # Update the NMNE capture keywords, appending new keywords if provided - nmne_capture_keywords += nmne_config.get("nmne_capture_keywords", []) - - return NMNEConfig(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index aafdbe5c..50549389 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -14,12 +14,12 @@ from pydantic import BaseModel, Field from primaite import getLogger from primaite.exceptions import NetworkError from primaite.interface.request import RequestResponse -from primaite.session.io import NMNEConfig from primaite.simulator import SIM_OUTPUT from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.system.applications.application import Application diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py new file mode 100644 index 00000000..c9cff5de --- /dev/null +++ b/src/primaite/simulator/network/nmne.py @@ -0,0 +1,25 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import List + +from pydantic import BaseModel, ConfigDict + + +class NMNEConfig(BaseModel): + """Store all the information to perform NMNE operations.""" + + model_config = ConfigDict(extra="forbid") + + capture_nmne: bool = False + """Indicates whether Malicious Network Events (MNEs) should be captured.""" + nmne_capture_keywords: List[str] = [] + """List of keywords to identify malicious network events.""" + capture_by_direction: bool = True + """Captures should be organized by traffic direction (inbound/outbound).""" + capture_by_ip_address: bool = False + """Captures should be organized by source or destination IP address.""" + capture_by_protocol: bool = False + """Captures should be organized by network protocol (e.g., TCP, UDP).""" + capture_by_port: bool = False + """Captures should be organized by source or destination port.""" + capture_by_keyword: bool = False + """Captures should be filtered and categorised based on specific keywords.""" diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 7f86d26d..ef789ba7 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -9,11 +9,11 @@ from gymnasium import spaces from primaite.game.agent.interface import ProxyAgent from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame -from primaite.session.io import store_nmne_config from primaite.simulator.network.hardware.base import NetworkInterface from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.web_browser import WebBrowser @@ -87,7 +87,7 @@ def test_nic(simulation): } # Apply the NMNE configuration settings - NetworkInterface.nmne_config = store_nmne_config(nmne_config) + NetworkInterface.nmne_config = NMNEConfig(**nmne_config) assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index b4162e58..debf5b1c 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,9 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from primaite.game.agent.observations.nic_observations import NICObservation -from primaite.session.io import store_nmne_config from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection @@ -35,7 +35,7 @@ def test_capture_nmne(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers assert web_server_nic.nmne == {} @@ -112,7 +112,7 @@ def test_describe_state_nmne(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Assert that initially, there are no captured MNEs on both web and database servers web_server_nic_state = web_server_nic.describe_state() @@ -221,7 +221,7 @@ def test_capture_nmne_observations(uc2_network: Network): } # Apply the NMNE configuration settings - NIC.nmne_config = store_nmne_config(nmne_config) + NIC.nmne_config = NMNEConfig(**nmne_config) # Define observations for the NICs of the database and web servers db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True) From bd1e23db7df686e1e50a5e5850a0a45c4dc509d5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 Jul 2024 15:25:02 +0100 Subject: [PATCH 53/95] 2676 - make ntwk intf use default nmne config --- src/primaite/simulator/network/hardware/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 50549389..6a25cbef 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -99,7 +99,7 @@ class NetworkInterface(SimComponent, ABC): pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." - nmne_config: ClassVar[NMNEConfig] = None + nmne_config: ClassVar[NMNEConfig] = NMNEConfig() "A dataclass defining malicious network events to be captured." nmne: Dict = Field(default_factory=lambda: {}) @@ -1167,7 +1167,7 @@ class Node(SimComponent): ip_address, network_interface.speed, "Enabled" if network_interface.enabled else "Disabled", - network_interface.nmne if self.nmne_config.capture_nmne else "Disabled", + network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled", ] ) print(table) From 0f3fa79ffea3adeeecdfbe00e60526bcf8b2f773 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 31 Jul 2024 15:47:18 +0100 Subject: [PATCH 54/95] #2706 - Actioning review comments on example notebook and terminal class --- src/primaite/notebooks/Terminal-Processing.ipynb | 8 +++++--- .../simulator/system/services/terminal/terminal.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index c9321b01..75b92422 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -50,7 +50,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "demonstrate how we obtain the Terminal component" + "The terminal can be accessed from a `HostNode` via the `software_manager` as demonstrated below. \n", + "\n", + "In the example, we have a basic network consisting of two computers " ] }, { @@ -89,7 +91,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The Terminal can be used to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n" + "The Terminal can be used to send requests to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n" ] }, { @@ -111,7 +113,7 @@ " target_ip_address=computer_b.network_interface[1].ip_address,\n", ")\n", "\n", - "# Send commmand to install RansomwareScript\n", + "# Send command to install RansomwareScript\n", "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)" ] }, diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 50d30a34..b6999694 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -124,13 +124,13 @@ class Terminal(Service): return RequestResponse(status="failure", data={}) def _remote_login(request: List[Any], context: Any) -> RequestResponse: - self._process_remote_login(username=request[0], password=request[1], ip_address=request[2]) - if self.is_connected: + login = self._process_remote_login(username=request[0], password=request[1], ip_address=request[2]) + if login: return RequestResponse(status="success", data={}) else: return RequestResponse(status="failure", data={}) - def _execute(request: List[Any], context: Any) -> RequestResponse: + def _execute_request(request: List[Any], context: Any) -> RequestResponse: """Execute an instruction.""" command: str = request[0] self.execute(command) @@ -156,7 +156,7 @@ class Terminal(Service): rm.add_request( "Execute", - request_type=RequestType(func=_execute, validator=_login_valid), + request_type=RequestType(func=_execute_request, validator=_login_valid), ) rm.add_request("Logoff", request_type=RequestType(func=_logoff, validator=_login_valid)) From 2abd1969fe618160df7e77b2899c7e0ab0c4f5bd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 Jul 2024 16:41:59 +0100 Subject: [PATCH 55/95] #2800 - Consolidate software install and uninstall to a single method --- .../simulator/network/hardware/base.py | 68 ------------------ .../simulator/system/core/software_manager.py | 70 ++++++++++--------- tests/conftest.py | 12 ++-- .../test_action_integration.py | 3 +- .../system/test_service_on_node.py | 4 +- .../test_simulation/test_request_response.py | 6 +- .../_network/_hardware/test_node_actions.py | 17 +++-- 7 files changed, 61 insertions(+), 119 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 15c44821..fd3f369d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1455,74 +1455,6 @@ class Node(SimComponent): else: return - def install_service(self, service: Service) -> None: - """ - Install a service on this node. - - :param service: Service instance that has not been installed on any node yet. - :type service: Service - """ - if service in self: - _LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.") - return - self.services[service.uuid] = service - service.parent = self - service.install() # Perform any additional setup, such as creating files for this service on the node. - self.sys_log.info(f"Installed service {service.name}") - _LOGGER.debug(f"Added service {service.name} to node {self.hostname}") - self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager)) - - def uninstall_service(self, service: Service) -> None: - """ - Uninstall and completely remove service from this node. - - :param service: Service object that is currently associated with this node. - :type service: Service - """ - if service not in self: - _LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.") - return - service.uninstall() # Perform additional teardown, such as removing files or restarting the machine. - self.services.pop(service.uuid) - service.parent = None - self.sys_log.info(f"Uninstalled service {service.name}") - self._service_request_manager.remove_request(service.name) - - def install_application(self, application: Application) -> None: - """ - Install an application on this node. - - :param application: Application instance that has not been installed on any node yet. - :type application: Application - """ - if application in self: - _LOGGER.warning( - f"Can't add application {application.name} to node {self.hostname}. It's already installed." - ) - return - self.applications[application.uuid] = application - application.parent = self - self.sys_log.info(f"Installed application {application.name}") - _LOGGER.debug(f"Added application {application.name} to node {self.hostname}") - self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager)) - - def uninstall_application(self, application: Application) -> None: - """ - Uninstall and completely remove application from this node. - - :param application: Application object that is currently associated with this node. - :type application: Application - """ - if application not in self: - _LOGGER.warning( - f"Can't remove application {application.name} from node {self.hostname}. It's not installed." - ) - return - self.applications.pop(application.uuid) - application.parent = None - self.sys_log.info(f"Uninstalled application {application.name}") - self._application_request_manager.remove_request(application.name) - def _shut_down_actions(self): """Actions to perform when the node is shut down.""" # Turn off all the services in the node diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index e2266c2d..9c4d7cf6 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from prettytable import MARKDOWN, PrettyTable +from primaite.simulator.core import RequestType from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.network.transmission.network_layer import IPProtocol @@ -20,9 +21,7 @@ if TYPE_CHECKING: from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP -from typing import Type, TypeVar - -IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware) +from typing import Type class SoftwareManager: @@ -51,7 +50,7 @@ class SoftwareManager: self.node = parent_node self.session_manager = session_manager self.software: Dict[str, Union[Service, Application]] = {} - self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {} + self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {} self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {} self.sys_log: SysLog = sys_log self.file_system: FileSystem = file_system @@ -104,33 +103,34 @@ class SoftwareManager: return True return False - def install(self, software_class: Type[IOSoftwareClass]): + def install(self, software_class: Type[IOSoftware]): """ Install an Application or Service. :param software_class: The software class. """ - # TODO: Software manager and node itself both have an install method. Need to refactor to have more logical - # separation of concerns. if software_class in self._software_class_to_name_map: self.sys_log.warning(f"Cannot install {software_class} as it is already installed") return software = software_class( software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server ) + software.parent = self.node if isinstance(software, Application): - software.install() + self.node.applications[software.uuid] = software + self.node._application_request_manager.add_request( + software.name, RequestType(func=software._request_manager) + ) + elif isinstance(software, Service): + self.node.services[software.uuid] = software + self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager)) + software.install() software.software_manager = self self.software[software.name] = software self.port_protocol_mapping[(software.port, software.protocol)] = software if isinstance(software, Application): software.operating_state = ApplicationOperatingState.CLOSED - - # add the software to the node's registry after it has been fully initialized - if isinstance(software, Service): - self.node.install_service(software) - elif isinstance(software, Application): - self.node.install_application(software) + self.node.sys_log.info(f"Installed {software.name}") def uninstall(self, software_name: str): """ @@ -138,25 +138,31 @@ class SoftwareManager: :param software_name: The software name. """ - if software_name in self.software: - self.software[software_name].uninstall() - software = self.software.pop(software_name) # noqa - if isinstance(software, Application): - self.node.uninstall_application(software) - elif isinstance(software, Service): - self.node.uninstall_service(software) - for key, value in self.port_protocol_mapping.items(): - if value.name == software_name: - self.port_protocol_mapping.pop(key) - break - for key, value in self._software_class_to_name_map.items(): - if value == software_name: - self._software_class_to_name_map.pop(key) - break - del software - self.sys_log.info(f"Uninstalled {software_name}") + if software_name not in self.software: + self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") return - self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") + + self.software[software_name].uninstall() + software = self.software.pop(software_name) # noqa + if isinstance(software, Application): + self.node.applications.pop(software.uuid) + self.node._application_request_manager.remove_request(software.name) + elif isinstance(software, Service): + self.node.services.pop(software.uuid) + software.uninstall() + self.node._service_request_manager.remove_request(software.name) + software.parent = None + for key, value in self.port_protocol_mapping.items(): + if value.name == software_name: + self.port_protocol_mapping.pop(key) + break + for key, value in self._software_class_to_name_map.items(): + if value == software_name: + self._software_class_to_name_map.pop(key) + break + del software + self.sys_log.info(f"Uninstalled {software_name}") + return def send_internal_payload(self, target_software: str, payload: Any): """ diff --git a/tests/conftest.py b/tests/conftest.py index 54519e2b..ca704461 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,14 +37,14 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) -class TestService(Service): +class DummyService(Service): """Test Service class""" def describe_state(self) -> Dict: return super().describe_state() def __init__(self, **kwargs): - kwargs["name"] = "TestService" + kwargs["name"] = "DummyService" kwargs["port"] = Port.HTTP kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) @@ -75,15 +75,15 @@ def uc2_network() -> Network: @pytest.fixture(scope="function") -def service(file_system) -> TestService: - return TestService( - name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") +def service(file_system) -> DummyService: + return DummyService( + name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service") ) @pytest.fixture(scope="function") def service_class(): - return TestService + return DummyService @pytest.fixture(scope="function") diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index a6f09436..7bdc80fc 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -22,8 +22,7 @@ def test_passing_actions_down(monkeypatch) -> None: for n in [pc1, pc2, srv, s1]: sim.network.add_node(n) - database_service = DatabaseService(file_system=srv.file_system) - srv.install_service(database_service) + srv.software_manager.install(DatabaseService) downloads_folder = pc1.file_system.create_folder("downloads") pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads") diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index 15dbaf1d..cf9728ce 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -23,7 +23,7 @@ def populated_node( server.power_on() server.software_manager.install(service_class) - service = server.software_manager.software.get("TestService") + service = server.software_manager.software.get("DummyService") service.start() return server, service @@ -42,7 +42,7 @@ def test_service_on_offline_node(service_class): computer.power_on() computer.software_manager.install(service_class) - service: Service = computer.software_manager.software.get("TestService") + service: Service = computer.software_manager.software.get("DummyService") computer.power_off() diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py index a9f0b58d..95634cf1 100644 --- a/tests/integration_tests/test_simulation/test_request_response.py +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port -from tests.conftest import DummyApplication, TestService +from tests.conftest import DummyApplication, DummyService def test_successful_node_file_system_creation_request(example_network): @@ -61,7 +61,7 @@ def test_successful_application_requests(example_network): def test_successful_service_requests(example_network): net = example_network server_1 = net.get_node_by_hostname("server_1") - server_1.software_manager.install(TestService) + server_1.software_manager.install(DummyService) # Careful: the order here is important, for example we cannot run "stop" unless we run "start" first for verb in [ @@ -77,7 +77,7 @@ def test_successful_service_requests(example_network): "scan", "fix", ]: - resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb]) + resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb]) assert resp_1 == RequestResponse(status="success", data={}) server_1.apply_timestep(timestep=1) server_1.apply_timestep(timestep=1) diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py index 9b37ac80..44c5c781 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/test_node_actions.py @@ -7,6 +7,7 @@ from primaite.simulator.file_system.folder import Folder from primaite.simulator.network.hardware.base import Node, NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.system.software import SoftwareHealthState +from tests.conftest import DummyApplication, DummyService @pytest.fixture @@ -47,7 +48,7 @@ def test_node_shutdown(node): assert node.operating_state == NodeOperatingState.OFF -def test_node_os_scan(node, service, application): +def test_node_os_scan(node): """Test OS Scanning.""" node.operating_state = NodeOperatingState.ON @@ -55,13 +56,15 @@ def test_node_os_scan(node, service, application): # TODO implement processes # add services to node + node.software_manager.install(DummyService) + service = node.software_manager.software.get("DummyService") service.set_health_state(SoftwareHealthState.COMPROMISED) - node.install_service(service=service) assert service.health_state_visible == SoftwareHealthState.UNUSED # add application to node + node.software_manager.install(DummyApplication) + application = node.software_manager.software.get("DummyApplication") application.set_health_state(SoftwareHealthState.COMPROMISED) - node.install_application(application=application) assert application.health_state_visible == SoftwareHealthState.UNUSED # add folder and file to node @@ -91,7 +94,7 @@ def test_node_os_scan(node, service, application): assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT -def test_node_red_scan(node, service, application): +def test_node_red_scan(node): """Test revealing to red""" node.operating_state = NodeOperatingState.ON @@ -99,12 +102,14 @@ def test_node_red_scan(node, service, application): # TODO implement processes # add services to node - node.install_service(service=service) + node.software_manager.install(DummyService) + service = node.software_manager.software.get("DummyService") assert service.revealed_to_red is False # add application to node + node.software_manager.install(DummyApplication) + application = node.software_manager.software.get("DummyApplication") application.set_health_state(SoftwareHealthState.COMPROMISED) - node.install_application(application=application) assert application.revealed_to_red is False # add folder and file to node From 2648614f97d2424c31e4fc1c208ebe0ce12dbd69 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 31 Jul 2024 16:44:25 +0100 Subject: [PATCH 56/95] 2800 update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 515be435..cebc2569 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits. - **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes. - **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity. +- **Software (un)install refactored**: Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. ### Fixed From e4e3e17f511322ce1f5a5735a071d4518ff5a2f5 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 1 Aug 2024 07:57:01 +0100 Subject: [PATCH 57/95] #2706 - commit minor changes from review comments --- src/primaite/simulator/system/services/terminal/terminal.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index b6999694..6df21618 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -185,9 +185,11 @@ class Terminal(Service): def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. - :param dest_ip_address: The IP address of the node we want to connect to. + If ip_address is passed, login will attempt a remote login to the terminal + :param username: The username credential. :param password: The user password component of credentials. + :param dest_ip_address: The IP address of the node we want to connect to. :return: True if successful, False otherwise. """ if self.operating_state != ServiceOperatingState.RUNNING: @@ -196,6 +198,8 @@ class Terminal(Service): if ip_address: # if ip_address has been provided, we assume we are logging in to a remote terminal. + if ip_address == self.parent.network_interface[1].ip_address: + return self._process_local_login(username=username, password=password) return self._send_remote_login(username=username, password=password, ip_address=ip_address) return self._process_local_login(username=username, password=password) From 5ef9e78a448192ec58f66429c80588bda84f93f7 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 1 Aug 2024 08:37:51 +0100 Subject: [PATCH 58/95] #2706 - Elaborated on terminal login within notebook --- .../notebooks/Terminal-Processing.ipynb | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 75b92422..fc795794 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -65,10 +65,25 @@ "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", - "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")\n", - "\n", + "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are logging in to the 'admin' account on `node_b`. \n", + "If you are not logged in, any commands sent will be rejected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "# Login to the remote (node_b) from local (node_a)\n", - "terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)\n" + "terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" ] }, { From b5992574339c2d28b5ab954d566d478317c4fc4e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Aug 2024 09:06:35 +0100 Subject: [PATCH 59/95] #2676 - update configs to use new nmne schema; fix test and warnings --- .../_package_data/scenario_with_placeholders/scenario.yaml | 4 ++++ src/primaite/simulator/network/protocols/icmp.py | 4 ++-- tests/conftest.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml index 81848b2d..dfd200f3 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml @@ -129,6 +129,10 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - hostname: client type: computer diff --git a/src/primaite/simulator/network/protocols/icmp.py b/src/primaite/simulator/network/protocols/icmp.py index 743e2375..9f0626f0 100644 --- a/src/primaite/simulator/network/protocols/icmp.py +++ b/src/primaite/simulator/network/protocols/icmp.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Union from pydantic import BaseModel, field_validator, validate_call -from pydantic_core.core_schema import FieldValidationInfo +from pydantic_core.core_schema import ValidationInfo from primaite import getLogger @@ -96,7 +96,7 @@ class ICMPPacket(BaseModel): @field_validator("icmp_code") # noqa @classmethod - def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int: + def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int: """Validates the icmp_type and icmp_code.""" icmp_type = info.data["icmp_type"] if get_icmp_type_code_description(icmp_type, v): diff --git a/tests/conftest.py b/tests/conftest.py index 54519e2b..2996e953 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer from tests import TEST_ASSETS_ROOT -rayinit(local_mode=True) +rayinit() ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 From 2a715d8d0a6c5c7d84871189daddc7c6268e6ab7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Aug 2024 11:08:41 +0100 Subject: [PATCH 60/95] backport 3.2 changes to core --- CHANGELOG.md | 355 +++++++----------- CONTRIBUTING.md | 3 - _config.yml | 3 + pyproject.toml | 2 +- src/primaite/game/agent/actions.py | 2 +- src/primaite/notebooks/Action-masking.ipynb | 9 +- .../Training-an-RLLIB-MARL-System.ipynb | 14 +- .../notebooks/Training-an-RLLib-Agent.ipynb | 13 +- .../simulator/network/hardware/base.py | 2 +- .../test_software_fix_duration.py | 12 +- 10 files changed, 154 insertions(+), 261 deletions(-) create mode 100644 _config.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 515be435..b5996f98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,257 +5,172 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] -### Added -- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency in the airspace. -- **Bandwidth Tracking**: Tracks data transmission across each frequency. -- **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. -- **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. +## [3.2.0] - 2024-07-18 + +### Added +- Action penalty is a reward component that applies a negative reward for doing any action other than DONOTHING +- Application configuration actions for RansomwareScript, DatabaseClient, and DoSBot applications +- Ability to configure how long it takes to apply the service fix action +- Terminal service using SSH +- Airspaces now track the amount of data being transmitted, viewable using the `show_bandwidth_load` method +- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML +- Agent logging for agents' internal decision logic +- Action masking in all PrimAITE environments ### Changed - -- **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`. -- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits. -- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes. -- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity. +- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed +- Databases can no longer respond to request while performing a backup +- Application install no longer accepts an `ip_address` parameter +- Application install action can now be used on all applications +- Actions have additional logic for checking validity +- Frame `size` attribute now includes both core size and payload size in bytes +- The `speed` attribute of `NetworkInterface` has been changed from `int` to `float` +- Tidied up CHANGELOG ### Fixed - -- **Transmission Permission Logic**: Corrected the logic in `can_transmit_frame` to accurately prevent overloads by checking if the transmission of a frame stays within allowable bandwidth limits after considering current load. +- Links and airspaces can no longer transmit data if this would exceed their bandwidth -[//]: # (This file needs tidying up between 2.0.0 and this line as it hasn't been segmented into 3.0.0 and 3.1.0 and isn't compliant with https://keepachangelog.com/en/1.1.0/) - -## 3.0.0b9 -- Removed deprecated `PrimaiteSession` class. -- Added ability to set log levels via configuration. -- Upgraded pydantic to version 2.7.0 -- Upgraded Ray to version >= 2.9 -- Added ipywidgets to the dependencies -- Added ability to define scenarios that change depending on the episode number. -- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config` -- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient -- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's. -- Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService. -- Added the ability for a DatabaseService to terminate a connection. -- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used. -- Added additional show functions to enable connection inspection. -- Updates to agent logging, to include the reward both per step and per episode. -- Introduced Developer CLI tools to assist with developing/debugging PrimAITE - - Can be enabled via `primaite dev-mode enable` - - Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located -- Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization. -- Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training. - -## [Unreleased] -- Made requests fail to reach their target if the node is off -- Added responses to requests -- Made environment reset completely recreate the game object. -- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. -- Changed the data manipulation scenario to include a second green agent on client 1. -- Refactored actions and observations to be configurable via object name, instead of UUID. -- Made database patch correctly take 2 timesteps instead of being immediate -- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING -- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function. -- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config. -- Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config. -- Added support for SQL INSERT command. -- Added ability to log each agent's action choices in each step to a JSON file. -- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present. -- Added NMAP application to all host and layer-3 network nodes. - -### Bug Fixes - -- ACL rules were not resetting on episode reset. -- ACLs were not showing up correctly in the observation space. -- Blue agent's ACL actions were being applied against the wrong IP addresses -- Deleted files and folders did not reset correctly on episode reset. -- Service health status was using the actual health state instead of the visible health state -- Database file health status was using the incorrect value for negative rewards -- Preventing file actions from reaching their intended file -- The data manipulation attack was triggered at episode start. -- FTP STOR stored an additional copy on the client machine's filesystem -- The red agent acted to early -- Order of service health state -- Starting a node didn't start the services on it -- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off -- The use of NODE_FILE_CHECKHASH and NODE_FOLDER_CHECKHASH in the current release is marked as 'Not Implemented'. - +## [3.1.0] - 2024-06-25 ### Added -- Network Hardware - Added base hardware module with NIC, SwitchPort, Node, and Link. Nodes have -fundamental services like ARP, ICMP, and PCAP running them by default. -- Network Transmission - Modelled OSI Model layers 1 through to 5 with various classes for creating network frames and -transmitting them from a Service/Application, down through the layers, over the wire, and back up through the layers to -a Service/Application another machine. -- Introduced `Router` and `Switch` classes to manage networking routes more effectively. - - Added `ACLRule` and `RouteTableEntry` classes as part of the `Router`. -- New `.show()` methods in all network component classes to inspect the state in either plain text or markdown formats. -- Added `Computer` and `Server` class to better differentiate types of network nodes. -- Integrated a new Use Case 2 network into the system. -- New unit tests to verify routing between different subnets using `.ping()`. -- system - Added the core structure of Application, Services, and Components. Also added a SoftwareManager and -SessionManager. -- Permission System - each action can define criteria that will be used to permit or deny agent actions. -- File System - ability to emulate a node's file system during a simulation -- Example notebooks - There are 5 jupyter notebook which walk through using PrimAITE - 1. Training a Stable Baselines 3 agent - 2. Training a single agent system using Ray RLLib - 3. Training a multi-agent system Ray RLLib - 4. Data manipulation end to end demonstration - 5. Data manipulation scenario with customised red agents -- Database: - - `DatabaseClient` and `DatabaseService` created to allow emulation of database actions - - Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup -- Red Agent Services: - - Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding. - - `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance. -- DNS Services: `DNSClient` and `DNSServer` -- FTP Services: `FTPClient` and `FTPServer` -- HTTP Services: `WebBrowser` to simulate a web client and `WebServer` -- NTP Services: `NTPClient` and `NTPServer` -- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic. - - **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required. - - **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance. -- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework. -- Introduced the `NetworkInterface` abstract class to provide a common interface for all network interfaces. Subclasses are divided into two main categories: `WiredNetworkInterface` and `WirelessNetworkInterface`, each serving as an abstract base class (ABC) for more specific interface types. Under `WiredNetworkInterface`, the subclasses `NIC` and `SwitchPort` were added. For wireless interfaces, `WirelessNIC` and `WirelessAccessPoint` are the subclasses under `WirelessNetworkInterface`. -- Added `Layer3Interface` as an abstract base class for networking functionalities at layer 3, including IP addressing and routing capabilities. This class is inherited by `NIC`, `WirelessNIC`, and `WirelessAccessPoint` to provide them with layer 3 capabilities, facilitating their role in both wired and wireless networking contexts with IP-based communication. -- Created the `ARP` and `ICMP` service classes to handle Address Resolution Protocol operations and Internet Control Message Protocol messages, respectively, with `RouterARP` and `RouterICMP` for router-specific implementations. -- Created `HostNode` as a subclass of `Node`, extending its functionality with host-specific services and applications. This class is designed to represent end-user devices like computers or servers that can initiate and respond to network communications. -- Introduced a new `IPV4Address` type in the Pydantic model for enhanced validation and auto-conversion of IPv4 addresses from strings using an `ipv4_validator`. -- Comprehensive documentation for the Node and its network interfaces, detailing the operational workflow from frame reception to application-level processing. -- Detailed descriptions of the Session Manager and Software Manager functionalities, including their roles in managing sessions, software services, and applications within the simulation. -- Documentation for the Packet Capture (PCAP) service and SysLog functionality, highlighting their importance in logging network frames and system events, respectively. -- Expanded documentation on network devices such as Routers, Switches, Computers, and Switch Nodes, explaining their specific processing logic and protocol support. -- **Firewall Node**: Introduced the `Firewall` class extending the functionality of the existing `Router` class. The `Firewall` class incorporates advanced features to scrutinize, direct, and filter traffic between various network zones, guided by predefined security rules and policies. Key functionalities include: - - Access Control Lists (ACLs) for traffic filtering based on IP addresses, protocols, and port numbers. - - Network zone segmentation for managing traffic across external, internal, and DMZ (De-Militarized Zone) networks. - - Interface configuration to establish connectivity and define network parameters for external, internal, and DMZ interfaces. - - Protocol and service management to oversee traffic and enforce security policies. - - Dynamic traffic processing and filtering to ensure network security and integrity. -- `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies. -- `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations. -- `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies. -- Documentation Updates: - - Examples include how to set up PrimAITE session via config - - Examples include how to create nodes and install software via config - - Examples include how to set up PrimAITE session via Python - - Examples include how to create nodes and install software via Python - - Added missing ``DoSBot`` documentation page - - Added diagrams where needed to make understanding some things easier - - Templated parts of the documentation to prevent unnecessary repetition and for easier maintaining of documentation - - Separated documentation pages of some items i.e. client and server software were on the same pages - which may make things confusing - - Configuration section at the bottom of the software pages specifying the configuration options available (and which ones are optional) -- Ability to add ``Firewall`` node via config -- Ability to add ``Router`` routes via config -- Ability to add ``Router``/``Firewall`` ``ACLRule`` via config -- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events. -- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE". -- Router-specific SessionManager Implementation: Introduced a specialized version of the SessionManager tailored for router operations. This enhancement enables the SessionManager to determine the routing path by consulting the route table. +- Observations for traffic amounts on host network interfaces +- NMAP application network discovery, including ping scan and port scan +- NMAP actions +- Automated adding copyright notices to source files +- More file types +- `show` method to files +- `model_dump` methods to network enums to enable better logging ### Changed -- Integrated the RouteTable into the Routers frame processing. -- Frames are now dropped when their TTL reaches 0 -- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts. - - **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting. - - **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios. -- Standardised the way network interfaces are accessed across all `Node` subclasses (`HostNode`, `Router`, `Switch`) by maintaining a comprehensive `network_interface` attribute. This attribute captures all network interfaces by their port number, streamlining the management and interaction with network interfaces across different types of nodes. -- Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework. -- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios. -- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules. -- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them. -- Integration of NMNE capturing functionality within the `NICObservation` class. -- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario +- Updated file system actions to stop failures when creating duplicate files +- Improved parsing of ACL add rule actions to make some parameters optional + +### Fixed +- Fixed database client uninstall failing due to persistent connections +- Fixed packet storm when pinging broadcast addresses + + +## [3.0.0] - 2024-06-10 + +### Added +- New simulation module +- Multi agent reinforcement learning support +- File system class to manage files and folders +- Software for nodes that can have its own behaviour +- Software classes to model FTP, Postgres databases, web traffic, NTP +- Much more detailed network simulation including packets, links, and network interfaces +- More node types: host, computer, server, router, switch, wireless router, and firewalls +- Network Hardware - NIC, SwitchPort, Node, and Link. Nodes have fundamental services like ARP, ICMP, and PCAP running them by default. +- Malicious network event detection +- New `game` module for managing agents +- ACL rule wildcard masking +- Network broadcasting +- Wireless transmission +- More detailed documentation +- Example jupyter notebooks to demonstrate new functionality +- More reward components +- Packet capture logs +- Node system logs +- Per-step full simulation state log +- Attack randomisation with respect to timing and attack source +- Ability to set log level via CLI +- Ability to vary the YAML configuration per-episode +- Developer CLI tools for enhanced debugging (with `primaite dev-mode enable`) +- `show` function to many simulation objects to inspect their current state + +### Changed +- Decoupled the environment from the simulation by adding the `game` interface layer +- Made agents share a common base class +- Added more actions +- Made all agents use CAOS actions, including red and green agents +- Reworked YAML configuration file schema +- Reworked the reward system to be component-based +- Changed agent logs to create a JSON output instead of CSV with more detailed action information +- Made observation space flattening optional +- Made all logging optional +- Agent actions now provide responses with a success code ### Removed -- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol` -- Removed legacy training modules -- Removed tests for legacy code +- Legacy simulation modules +- Legacy training modules +- Tests for legacy code +- Hardcoded IERs and PoL, traffic generation is now handled by agents and software +- Inbuilt agent training scripts -### Fixed -- Addressed network transmission issues that previously allowed ARP requests to be incorrectly routed and repeated across different subnets. This fix ensures ARP requests are correctly managed and confined to their appropriate network segments. -- Resolved problems in `Node` and its subclasses where the default gateway configuration was not properly utilized for communications across different subnets. This correction ensures that nodes effectively use their configured default gateways for outbound communications to other network segments, thereby enhancing the network's routing functionality and reliability. -- Network Interface Port name/num being set properly for sys log and PCAP output. ## [2.0.0] - 2023-07-26 ### Added -- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE. -- Application Directories to enable PrimAITE as a Python package with predefined directories for storage. -- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib. -- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`. -- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options. -- Session loading to revisit previously run sessions for SB3 Agents. -- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface. -- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs: - 1. Session Metadata - 2. Results - 3. Diagrams - 4. Saved agents (training checkpoints and a final trained agent). -- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup. -- Benchmarking of PrimAITE performance, showcasing session and step durations for reference. -- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`. +- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE. +- Application Directories to enable PrimAITE as a Python package with predefined directories for storage. +- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib. +- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`. +- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options. +- Session loading to revisit previously run sessions for SB3 Agents. +- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface. +- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs: Session Metadata, Results, Diagrams, Trained agents. +- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup. +- Benchmarking of PrimAITE performance, showcasing session and step durations for reference. +- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`. ### Changed -- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions. -- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`. -- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names. -- Docs and Tests now sit outside the `src/` directory. -- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages. -- All dependencies are now defined in the `pyproject.toml` file. -- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each. -- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management. -- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations. -- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority. +- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions. +- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`. +- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names. +- Docs and Tests now sit outside the `src/` directory. +- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages. +- All dependencies are now defined in the `pyproject.toml` file. +- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each. +- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management. +- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations. +- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority. ### Fixed -- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation. -- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes. -- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands. -- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making. +- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation. +- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes. +- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands. +- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making. + ## [1.1.1] - 2023-06-27 -### Bug Fixes -* Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by: - * Implementing a reference copy of all green IERs (`self.green_iers_reference`). - * Clearing the traffic on reference IERs at the same time as the live IERs. - * Passing the `green_iers_reference` to the `apply_iers` function at the reference stage. - * Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`. - * Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment. - * Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes. - * Removing the unnecessary "Reapply PoL and IERs" action from the step function. - * Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function. +### Fixed +- Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by: + - Implementing a reference copy of all green IERs (`self.green_iers_reference`). + - Clearing the traffic on reference IERs at the same time as the live IERs. + - Passing the `green_iers_reference` to the `apply_iers` function at the reference stage. + - Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`. + - Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment. + - Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes. + - Removing the unnecessary "Reapply PoL and IERs" action from the step function. + - Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function. + ## [1.1.0] - 2023-03-13 ### Added -* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function. -* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name. -* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components: - * Blue agent observation space; - * Blue agent action space; - * Reward function; - * Node pattern-of-life. -* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network). -* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile. -* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability. -* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value. -* The User Guide has been updated to reflect all the above changes. +- The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function. +- The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name. +- Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components: + - Blue agent observation space; + - Blue agent action space; + - Reward function; + - Node pattern-of-life. +- The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network). +- New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile. +- NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability. +- The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value. +- The User Guide has been updated to reflect all the above changes. ### Changed -* "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing. -* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. -* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. -* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file. -* Updates to Transactions. +- "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing. +- "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. +- "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement. +- "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file. +- Updates to Transactions. ### Fixed -* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default. - - - -[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD -[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0 +- Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bf5e75e4..dc10edbb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,9 +13,6 @@ * [Fork the repository](https://github.com/{todo:fill in URL}/PrimAITE/fork). * Install the pre-commit hook with `pre-commit install`. * Implement the bug fix. -* Update documentation where applicable. -* Update the **UNRELEASED** section of the [CHANGELOG.md](CHANGELOG.md) file -* Write a suitable test/tests. * Commit the bug fix to the dev branch on your fork. If the bug has an open issue under [Issues](https://github.com/{todo:fill in URL}/PrimAITE/issues), reference the issue in the commit message (e.g. #1 references issue 1). * Submit a pull request from your dev branch to the {todo:fill in URL}/PrimAITE dev branch. Again, if the bug has an open issue under [Issues](https://github.com/{todo:fill in URL}/PrimAITE/issues), reference the issue in the pull request description. diff --git a/_config.yml b/_config.yml new file mode 100644 index 00000000..b4654829 --- /dev/null +++ b/_config.yml @@ -0,0 +1,3 @@ +# Used by nbmake to change build pipeline notebook timeout +execute: + timeout: 600 diff --git a/pyproject.toml b/pyproject.toml index 9e919604..e29fd504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.20.0, < 3", + "ray[rllib] >= 2.20.0, < 2.33", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", "sb3-contrib==2.1.0", diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 9a5fedc9..7263cfc1 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -294,7 +294,7 @@ class ConfigureDoSBotAction(AbstractAction): """Action which sets config parameters for a DoS bot on a node.""" class _Opts(BaseModel): - """Schema for options that can be passed to this option.""" + """Schema for options that can be passed to this action.""" model_config = ConfigDict(extra="forbid") target_ip_address: Optional[str] = None diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index 8811bb15..ba70f2b4 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -101,7 +101,6 @@ "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import yaml\n", - "from ray import air, tune\n", "from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n", "from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n" ] @@ -135,8 +134,7 @@ " .training(train_batch_size=128)\n", ")\n", "algo = config.build()\n", - "for i in range(2):\n", - " results = algo.train()" + "results = algo.train()" ] }, { @@ -191,8 +189,7 @@ " .training(train_batch_size=128)\n", ")\n", "algo = config.build()\n", - "for i in range(2):\n", - " results = algo.train()" + "results = algo.train()" ] } ], @@ -212,7 +209,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index c185b8b5..28f08edd 100644 --- a/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -24,14 +24,11 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.game.game import PrimaiteGame\n", "import yaml\n", "\n", - "from primaite.session.ray_envs import PrimaiteRayEnv\n", "from primaite import PRIMAITE_PATHS\n", "\n", "import ray\n", - "from ray import air, tune\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "from primaite.session.ray_envs import PrimaiteRayMARLEnv\n", "\n", @@ -72,7 +69,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Set training parameters and start the training\n", + "#### Start the training\n", "This example will save outputs to a default Ray directory and use mostly default settings." ] }, @@ -82,13 +79,8 @@ "metadata": {}, "outputs": [], "source": [ - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 5 * 128},\n", - " ),\n", - " param_space=config\n", - ").fit()" + "algo = config.build()\n", + "results = algo.train()" ] } ], diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index bdd60f36..9d870192 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -17,12 +17,10 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.game.game import PrimaiteGame\n", "import yaml\n", "from primaite.config.load import data_manipulation_config_path\n", "\n", "from primaite.session.ray_envs import PrimaiteRayEnv\n", - "from ray import air, tune\n", "import ray\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "\n", @@ -64,7 +62,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Set training parameters and start the training" + "#### Start the training" ] }, { @@ -73,13 +71,8 @@ "metadata": {}, "outputs": [], "source": [ - "tune.Tuner(\n", - " \"PPO\",\n", - " run_config=air.RunConfig(\n", - " stop={\"timesteps_total\": 512}\n", - " ),\n", - " param_space=config\n", - ").fit()\n" + "algo = config.build()\n", + "results = algo.train()\n" ] } ], diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 15c44821..7a127601 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -984,7 +984,7 @@ class Node(SimComponent): application_name = request[0] if self.software_manager.software.get(application_name): self.sys_log.warning(f"Can't install {application_name}. It's already installed.") - return RequestResponse.from_bool(False) + return RequestResponse(status="success", data={"reason": "already installed"}) application_class = Application._application_registry[application_name] self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) diff --git a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py index ae4825ff..dd38fafd 100644 --- a/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py +++ b/tests/integration_tests/configuration_file_parsing/test_software_fix_duration.py @@ -45,7 +45,7 @@ def test_fix_duration_set_from_config(): client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix - for service in SERVICE_TYPES_MAPPING: + for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: assert client_1.software_manager.software.get(service) is not None assert client_1.software_manager.software.get(service).fixing_duration == 3 @@ -53,7 +53,7 @@ def test_fix_duration_set_from_config(): # remove test applications from list applications = set(Application._application_registry) - set(TestApplications) - for application in applications: + for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]: assert client_1.software_manager.software.get(application) is not None assert client_1.software_manager.software.get(application).fixing_duration == 1 @@ -64,17 +64,13 @@ def test_fix_duration_for_one_item(): client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") # in config - services take 3 timesteps to fix - services = copy.copy(SERVICE_TYPES_MAPPING) - services.pop("DatabaseService") - for service in services: + for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]: assert client_1.software_manager.software.get(service) is not None assert client_1.software_manager.software.get(service).fixing_duration == 2 # in config - applications take 1 timestep to fix # remove test applications from list - applications = set(Application._application_registry) - set(TestApplications) - applications.remove("DatabaseClient") - for applications in applications: + for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]: assert client_1.software_manager.software.get(applications) is not None assert client_1.software_manager.software.get(applications).fixing_duration == 2 From 19d7774440c2e11b5bfef3fc55a6daa3bb40c88a Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 1 Aug 2024 12:34:21 +0100 Subject: [PATCH 61/95] #2706 - Changed how Terminal Class handles its connections. Terminal now has a list of TerminalClientConnection objects that holds all active connections. Corrected a typo in ssh.py --- .../simulator/network/protocols/ssh.py | 2 +- .../system/services/terminal/terminal.py | 74 +++++++++++-------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 8671a1c8..495a2a2b 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -22,7 +22,7 @@ class SSHTransportMessage(IntEnum): """Indicates User Authentication failed.""" SSH_MSG_USERAUTH_SUCCESS = 52 - """Indicates User Authentication failed was successful.""" + """Indicates User Authentication was successful.""" SSH_MSG_SERVICE_REQUEST = 24 """Requests a service - such as executing a command.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 6df21618..998238a9 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -43,6 +43,17 @@ class TerminalClientConnection(BaseModel): _connection_uuid: str = None """Connection UUID""" + @property + def is_local(self) -> bool: + """Indicates if connection is remote or local. + + Returns True if local, False if remote. + """ + for interface in self.parent_node.network_interface: + if self.dest_ip_address == self.parent_node.network_interface[interface].ip_address: + return True + return False + @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" @@ -57,9 +68,6 @@ class TerminalClientConnection(BaseModel): class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - is_connected: bool = False - "Boolean Value for whether connected" - connection_uuid: Optional[str] = None "Uuid for connection requests" @@ -69,7 +77,8 @@ class Terminal(Service): health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD "Service Health State" - remote_connection: Dict[str, TerminalClientConnection] = {} + _connections: Dict[str, TerminalClientConnection] = {} + "List of active connections held on this terminal." def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -95,13 +104,13 @@ class Terminal(Service): :param markdown: Whether to display the table in Markdown format or not. Default is `False`. """ - table = PrettyTable(["Connection ID", "IP_Address", "Active"]) + table = PrettyTable(["Connection ID", "IP_Address", "Active", "Local"]) if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.sys_log.hostname} {self.name} Remote Connections" - for connection_id, connection in self.remote_connection.items(): - table.add_row([connection_id, connection.dest_ip_address, connection.is_active]) + table.title = f"{self.sys_log.hostname} {self.name} Connections" + for connection_id, connection in self._connections.items(): + table.add_row([connection_id, connection.dest_ip_address, connection.is_active, connection.is_local]) print(table.get_string(sortby="Connection ID")) def _init_request_manager(self) -> RequestManager: @@ -182,11 +191,18 @@ class Terminal(Service): """Message that is reported when a request is rejected by this validator.""" return "Cannot perform request on terminal as not logged in." + def _add_new_connection(self, connection_uuid: str, dest_ip_address: IPv4Address): + """Create a new connection object and amend to list of active connections.""" + self._connections[connection_uuid] = TerminalClientConnection( + parent_node=self.software_manager.node, + dest_ip_address=dest_ip_address, + connection_uuid=connection_uuid, + ) + def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: """Process User request to login to Terminal. - If ip_address is passed, login will attempt a remote login to the terminal - + If ip_address is passed, login will attempt a remote login to the node at that address. :param username: The username credential. :param password: The user password component of credentials. :param dest_ip_address: The IP address of the node we want to connect to. @@ -198,8 +214,6 @@ class Terminal(Service): if ip_address: # if ip_address has been provided, we assume we are logging in to a remote terminal. - if ip_address == self.parent.network_interface[1].ip_address: - return self._process_local_login(username=username, password=password) return self._send_remote_login(username=username, password=password, ip_address=ip_address) return self._process_local_login(username=username, password=password) @@ -207,11 +221,14 @@ class Terminal(Service): def _process_local_login(self, username: str, password: str) -> bool: """Local session login to terminal.""" # TODO: Un-comment this when UserSessionManager is merged. - # self.connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) - self.connection_uuid = str(uuid4()) - self.is_connected = True - if self.connection_uuid: + # connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) + connection_uuid = str(uuid4()) + if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") + # Add new local session to list of connections + self._add_new_connection( + connection_uuid=connection_uuid, dest_ip_address=self.parent.network_interface[1].ip_address + ) return True else: self.sys_log.warning("Login failed, incorrect Username or Password") @@ -246,11 +263,10 @@ class Terminal(Service): # TODO: Un-comment this when UserSessionManager is merged. # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) connection_uuid = str(uuid4()) - self.is_connected = True if connection_uuid: # Send uuid to remote self.sys_log.info( - f"Remote login authorised, connection ID {self.connection_uuid} for " + f"Remote login authorised, connection ID {connection_uuid} for " f"{username} on {payload.sender_ip_address}" ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS @@ -262,12 +278,7 @@ class Terminal(Service): sender_ip_address=self.parent.network_interface[1].ip_address, target_ip_address=payload.sender_ip_address, ) - - self.remote_connection[connection_uuid] = TerminalClientConnection( - parent_node=self.software_manager.node, - dest_ip_address=payload.sender_ip_address, - connection_uuid=connection_uuid, - ) + self._add_new_connection(connection_uuid=connection_uuid, dest_ip_address=payload.sender_ip_address) self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) return True @@ -280,7 +291,7 @@ class Terminal(Service): """Receive Payload and process for a response. :param payload: The message contents received. - :return: True if successfull, else False. + :return: True if successful, else False. """ self.sys_log.debug(f"Received payload: {payload}") @@ -304,7 +315,6 @@ class Terminal(Service): elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") self.connection_uuid = payload.connection_uuid - self.is_connected = True return True elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: @@ -322,14 +332,15 @@ class Terminal(Service): def _disconnect(self, dest_ip_address: IPv4Address) -> bool: """Disconnect from the remote.""" - if not self.is_connected: - self.sys_log.warning("Not currently connected to remote") - return False - - if not self.remote_connection: + if not self._connections: self.sys_log.warning("No remote connection present") return False + # TODO: This should probably be done entirely by connection uuid and not IP_address. + for connection in self._connections: + if dest_ip_address == self._connections[connection].dest_ip_address: + self._connections.pop(connection) + software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "disconnect", "connection_id": self.connection_uuid}, @@ -347,7 +358,6 @@ class Terminal(Service): :return: True if successful, False otherwise. """ self._disconnect(dest_ip_address=dest_ip_address) - self.is_connected = False def send( self, From 78ad95fcef835b2b62f03a3ab724cf564a2e400f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 1 Aug 2024 13:58:35 +0100 Subject: [PATCH 64/95] #2735 - fix up node request manager and system software --- .../simulator/network/hardware/base.py | 36 +++++++++---------- .../network/hardware/nodes/host/host_node.py | 22 +++++------- .../network/hardware/nodes/network/router.py | 10 ++++-- 3 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index cbe8db64..d2aa4604 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -6,7 +6,7 @@ import secrets from abc import ABC, abstractmethod from ipaddress import IPv4Address, IPv4Network from pathlib import Path -from typing import Any, Dict, List, Optional, TypeVar, Union +from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel, Field, validate_call @@ -39,7 +39,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.processes.process import Process from primaite.simulator.system.services.service import Service -from primaite.simulator.system.software import IOSoftware +from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values from primaite.utils.validators import IPV4Address @@ -897,6 +897,10 @@ class UserManager(Service): table.add_row([user.username, user.is_admin, user.disabled]) print(table.get_string(sortby="Username")) + def install(self) -> None: + """Setup default user during first-time installation.""" + self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True) + def _is_last_admin(self, username: str) -> bool: return username in self.admins and len(self.admins) == 1 @@ -1100,9 +1104,6 @@ class UserSessionManager(Service): This class handles authentication, session management, and session timeouts for users interacting with the Node. """ - node: Node - """The node associated with this UserSessionManager.""" - local_session: Optional[UserSession] = None """The current local user session, if any.""" @@ -1183,7 +1184,7 @@ class UserSessionManager(Service): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.node.hostname} User Sessions" + table.title = f"{self.parent.hostname} User Sessions" def _add_session_to_table(user_session: UserSession): """ @@ -1472,6 +1473,9 @@ class Node(SimComponent): red_scan_countdown: int = 0 "Time steps until reveal to red scan is complete." + SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {} + "Base system software that must be preinstalled." + def __init__(self, **kwargs): """ Initialize the Node with various components and managers. @@ -1496,21 +1500,10 @@ class Node(SimComponent): dns_server=kwargs.get("dns_server"), ) super().__init__(**kwargs) + self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager - self.software_manager.install(UserSessionManager, node=self) - self._request_manager.add_request( - "sessions", RequestType(func=self.user_session_manager._request_manager) - ) # noqa - - self.software_manager.install(UserManager) - self._request_manager.add_request("accounts", RequestType(func=self.user_manager._request_manager)) # noqa - - self.user_manager.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True) - - self._install_system_software() - @property def user_manager(self) -> UserManager: """The Nodes User Manager.""" @@ -1767,8 +1760,6 @@ class Node(SimComponent): "services": {svc.name: svc.describe_state() for svc in self.services.values()}, "process": {proc.name: proc.describe_state() for proc in self.processes.values()}, "revealed_to_red": self.revealed_to_red, - "user_manager": self.user_manager.describe_state(), - "user_session_manager": self.user_session_manager.describe_state(), } ) return state @@ -2134,6 +2125,11 @@ class Node(SimComponent): # for process_id in self.processes: # self.processes[process_id] + def _install_system_software(self) -> None: + """Preinstall required software.""" + for _, software_class in self.SYSTEM_SOFTWARE.items(): + self.software_manager.install(software_class) + def __contains__(self, item: Any) -> bool: if isinstance(item, Service): return item.uuid in self.services diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index aac57e95..22c50bef 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -5,7 +5,13 @@ from ipaddress import IPv4Address from typing import Any, ClassVar, Dict, Optional from primaite import getLogger -from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node +from primaite.simulator.network.hardware.base import ( + IPWiredNetworkInterface, + Link, + Node, + UserManager, + UserSessionManager, +) from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import ApplicationOperatingState @@ -306,8 +312,8 @@ class HostNode(Node): "NTPClient": NTPClient, "WebBrowser": WebBrowser, "NMAP": NMAP, - # "UserSessionManager": UserSessionManager, - # "UserManager": UserManager, + "UserSessionManager": UserSessionManager, + "UserManager": UserManager, } """List of system software that is automatically installed on nodes.""" @@ -340,16 +346,6 @@ class HostNode(Node): """ return self.software_manager.software.get("ARP") - def _install_system_software(self): - """ - Installs the system software and network services typically found on an operating system. - - This method equips the host with essential network services and applications, preparing it for various - network-related tasks and operations. - """ - for _, software_class in self.SYSTEM_SOFTWARE.items(): - self.software_manager.install(software_class) - def default_gateway_hello(self): """ Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address. diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 61b7b96a..42821120 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -4,14 +4,14 @@ from __future__ import annotations import secrets from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable from pydantic import validate_call from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent -from primaite.simulator.network.hardware.base import IPWiredNetworkInterface +from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode from primaite.simulator.network.protocols.arp import ARPPacket @@ -1200,6 +1200,11 @@ class Router(NetworkNode): RouteTable, RouterARP, and RouterICMP services. """ + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "UserSessionManager": UserSessionManager, + "UserManager": UserManager, + } + num_ports: int network_interfaces: Dict[str, RouterInterface] = {} "The Router Interfaces on the node." @@ -1235,6 +1240,7 @@ class Router(NetworkNode): resolution within the network. These services are crucial for the router's operation, enabling it to manage network traffic efficiently. """ + super()._install_system_software() self.software_manager.install(RouterICMP) icmp: RouterICMP = self.software_manager.icmp # noqa icmp.router = self From 0fe61576c768839429a4802ba5ec89b4ac8f48ba Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 2 Aug 2024 09:13:31 +0100 Subject: [PATCH 65/95] #2706 - Removed source and target ip_address attributes from the SSHPacket Class. Terminal now uses session_id to send login outcome. No more network_interface[1].ip_address. --- .../simulator/network/protocols/ssh.py | 7 -- .../system/services/terminal/terminal.py | 112 +++++++++--------- 2 files changed, 53 insertions(+), 66 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 495a2a2b..4ec043b8 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -1,7 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from enum import IntEnum -from ipaddress import IPv4Address from typing import Optional from primaite.interface.request import RequestResponse @@ -68,12 +67,6 @@ class SSHUserCredentials(DataPacket): class SSHPacket(DataPacket): """Represents an SSHPacket.""" - sender_ip_address: IPv4Address - """Sender IP Address""" - - target_ip_address: IPv4Address - """Target IP Address""" - transport_message: SSHTransportMessage """Message Transport Type""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 998238a9..192f0551 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -28,32 +28,21 @@ class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. - This class is used to record current remote User Connections to the Terminal class. + This class is used to record current User Connections to the Terminal class. """ parent_node: Node # Technically should be HostNode but this causes circular import error. """The parent Node that this connection was created on.""" - is_active: bool = True - """Flag to state whether the connection is still active or not.""" - dest_ip_address: IPv4Address = None """Destination IP address of connection""" + session_id: str = None + """Session ID that connection is linked to""" + _connection_uuid: str = None """Connection UUID""" - @property - def is_local(self) -> bool: - """Indicates if connection is remote or local. - - Returns True if local, False if remote. - """ - for interface in self.parent_node.network_interface: - if self.dest_ip_address == self.parent_node.network_interface[interface].ip_address: - return True - return False - @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" @@ -68,9 +57,6 @@ class TerminalClientConnection(BaseModel): class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - connection_uuid: Optional[str] = None - "Uuid for connection requests" - operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING "Initial Operating State" @@ -104,13 +90,13 @@ class Terminal(Service): :param markdown: Whether to display the table in Markdown format or not. Default is `False`. """ - table = PrettyTable(["Connection ID", "IP_Address", "Active", "Local"]) + table = PrettyTable(["Connection ID", "Session_ID"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.sys_log.hostname} {self.name} Connections" for connection_id, connection in self._connections.items(): - table.add_row([connection_id, connection.dest_ip_address, connection.is_active, connection.is_local]) + table.add_row([connection_id, connection.session_id]) print(table.get_string(sortby="Connection ID")) def _init_request_manager(self) -> RequestManager: @@ -145,11 +131,12 @@ class Terminal(Service): self.execute(command) return RequestResponse(status="success", data={}) - def _logoff() -> RequestResponse: + def _logoff(request: List[Any]) -> RequestResponse: """Logoff from connection.""" + connection_uuid = request[0] # TODO: Uncomment this when UserSessionManager merged. - # self.parent.UserSessionManager.logoff(self.connection_uuid) - self.disconnect(self.connection_uuid) + # self.parent.UserSessionManager.logoff(connection_uuid) + self.disconnect(connection_uuid) return RequestResponse(status="success", data={}) @@ -191,12 +178,12 @@ class Terminal(Service): """Message that is reported when a request is rejected by this validator.""" return "Cannot perform request on terminal as not logged in." - def _add_new_connection(self, connection_uuid: str, dest_ip_address: IPv4Address): + def _add_new_connection(self, connection_uuid: str, session_id: str): """Create a new connection object and amend to list of active connections.""" self._connections[connection_uuid] = TerminalClientConnection( parent_node=self.software_manager.node, - dest_ip_address=dest_ip_address, connection_uuid=connection_uuid, + session_id=session_id, ) def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: @@ -219,23 +206,31 @@ class Terminal(Service): return self._process_local_login(username=username, password=password) def _process_local_login(self, username: str, password: str) -> bool: - """Local session login to terminal.""" + """Local session login to terminal. + + :param username: Username for login. + :param password: Password for login. + :return: boolean, True if successful, else False + """ # TODO: Un-comment this when UserSessionManager is merged. # connection_uuid = self.parent.UserSessionManager.login(username=username, password=password) connection_uuid = str(uuid4()) if connection_uuid: - self.sys_log.info(f"Login request authorised, connection uuid: {self.connection_uuid}") + self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections - self._add_new_connection( - connection_uuid=connection_uuid, dest_ip_address=self.parent.network_interface[1].ip_address - ) + self._add_new_connection(connection_uuid=connection_uuid) return True else: self.sys_log.warning("Login failed, incorrect Username or Password") return False def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: - """Attempt to login to a remote terminal.""" + """Attempt to login to a remote terminal. + + :param username: username for login. + :param password: password for login. + :ip_address: IP address of the target node for login. + """ transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password) @@ -244,14 +239,12 @@ class Terminal(Service): transport_message=transport_message, connection_message=connection_message, user_account=user_account, - target_ip_address=ip_address, - sender_ip_address=self.parent.network_interface[1].ip_address, ) self.sys_log.info(f"Sending remote login request to {ip_address}") return self.send(payload=payload, dest_ip_address=ip_address) - def _process_remote_login(self, payload: SSHPacket) -> bool: + def _process_remote_login(self, payload: SSHPacket, session_id: str) -> bool: """Processes a remote terminal requesting to login to this terminal. :param payload: The SSH Payload Packet. @@ -266,8 +259,7 @@ class Terminal(Service): if connection_uuid: # Send uuid to remote self.sys_log.info( - f"Remote login authorised, connection ID {connection_uuid} for " - f"{username} on {payload.sender_ip_address}" + f"Remote login authorised, connection ID {connection_uuid} for " f"{username} in session {session_id}" ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA @@ -275,19 +267,17 @@ class Terminal(Service): transport_message=transport_message, connection_message=connection_message, connection_uuid=connection_uuid, - sender_ip_address=self.parent.network_interface[1].ip_address, - target_ip_address=payload.sender_ip_address, ) - self._add_new_connection(connection_uuid=connection_uuid, dest_ip_address=payload.sender_ip_address) + self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id) - self.send(payload=return_payload, dest_ip_address=return_payload.target_ip_address) + self.send(payload=return_payload, session_id=session_id) return True else: # UserSessionManager has returned None self.sys_log.warning("Login failed, incorrect Username or Password") return False - def receive(self, payload: SSHPacket, **kwargs) -> bool: + def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: """Receive Payload and process for a response. :param payload: The message contents received. @@ -310,11 +300,10 @@ class Terminal(Service): self.sys_log.debug(f"Disconnecting {connection_id}") elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - return self._process_remote_login(payload=payload) + return self._process_remote_login(payload=payload, session_id=session_id) elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") - self.connection_uuid = payload.connection_uuid return True elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: @@ -330,16 +319,18 @@ class Terminal(Service): """Execute a passed ssh command via the request manager.""" return self.parent.apply_request(command) - def _disconnect(self, dest_ip_address: IPv4Address) -> bool: - """Disconnect from the remote.""" + def _disconnect(self, connection_uuid: str) -> bool: + """Disconnect from the remote. + + :param connection_uuid: Connection ID that we want to disconnect. + :return True if successful, False otherwise. + """ if not self._connections: self.sys_log.warning("No remote connection present") return False - # TODO: This should probably be done entirely by connection uuid and not IP_address. - for connection in self._connections: - if dest_ip_address == self._connections[connection].dest_ip_address: - self._connections.pop(connection) + dest_ip_address = self._connections[connection_uuid].dest_ip_address + self._connections.pop(connection_uuid) software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( @@ -347,22 +338,23 @@ class Terminal(Service): dest_ip_address=dest_ip_address, dest_port=self.port, ) - self.connection_uuid = None - self.sys_log.info(f"{self.name}: Disconnected {self.connection_uuid}") + self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") return True - def disconnect(self, dest_ip_address: IPv4Address) -> bool: - """Disconnect from remote connection. + def disconnect(self, connection_uuid: Optional[str]) -> bool: + """Disconnect the terminal. - :param dest_ip_address: The IP address fo the connection we are terminating. + If no connection id has been supplied, disconnects the first connection. + :param connection_uuid: Connection ID that we want to disconnect. :return: True if successful, False otherwise. """ - self._disconnect(dest_ip_address=dest_ip_address) + if not connection_uuid: + connection_uuid = next(iter(self._connections)) + + return self._disconnect(connection_uuid=connection_uuid) def send( - self, - payload: SSHPacket, - dest_ip_address: IPv4Address, + self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None ) -> bool: """ Send a payload out from the Terminal. @@ -374,4 +366,6 @@ class Terminal(Service): self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") return False self.sys_log.debug(f"Sending payload: {payload}") - return super().send(payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port) + return super().send( + payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id + ) From c2a19af6fa259d9cf4ac4525075b8b8900936ac3 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 2 Aug 2024 09:20:00 +0100 Subject: [PATCH 66/95] #2735 - added documentation for users, usermanager and usersessionmanager. Added the ability to add additional users from config and documented this. also tested additional users from config. --- .../nodes/common/common_node_attributes.rst | 27 +++ .../network/base_hardware.rst | 206 +++++++++++++++++- src/primaite/game/game.py | 8 +- .../assets/configs/basic_node_with_users.yaml | 34 +++ .../test_users_creation_from_config.py | 26 +++ 5 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 tests/assets/configs/basic_node_with_users.yaml create mode 100644 tests/integration_tests/network/test_users_creation_from_config.py diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index e648e4a1..7cf11eb4 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -53,3 +53,30 @@ The number of time steps required to occur in order for the node to cycle from ` Optional. Default value is ``3``. The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. + +``users`` +--------- + +The list of pre-existing users that are additional to the default admin user (``username=admin``, ``password=admin``). +Additional users are configured as an array nd must contain a ``username``, ``password``, and can contain an optional +boolean ``is_admin``. + +Example of adding two additional users to a node: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + users: + - username: jane.doe + password: '1234' + is_admin: true + - username: john.doe + password: password_1 + is_admin: false diff --git a/docs/source/simulation_components/network/base_hardware.rst b/docs/source/simulation_components/network/base_hardware.rst index 9e42b1de..ce1e5c74 100644 --- a/docs/source/simulation_components/network/base_hardware.rst +++ b/docs/source/simulation_components/network/base_hardware.rst @@ -97,8 +97,8 @@ Node Behaviours/Functions - **receive_frame()**: Handles the processing of incoming network frames. - **apply_timestep()**: Advances the state of the node according to the simulation timestep. - **power_on()**: Initiates the node, enabling all connected Network Interfaces and starting all Services and - Applications, taking into account the `start_up_duration`. -- **power_off()**: Stops the node's operations, adhering to the `shut_down_duration`. + Applications, taking into account the ``start_up_duration``. +- **power_off()**: Stops the node's operations, adhering to the ``shut_down_duration``. - **ping()**: Sends ICMP echo requests to a specified IP address to test connectivity. - **has_enabled_network_interface()**: Checks if the node has any network interfaces enabled, facilitating network communication. @@ -109,3 +109,205 @@ Node Behaviours/Functions The Node class handles installation of system software, network connectivity, frame processing, system logging, and power states. It establishes baseline functionality while allowing subclassing to model specific node types like hosts, routers, firewalls etc. The flexible architecture enables composing complex network topologies. + +User, UserManager, and UserSessionManager +========================================= + +The ``base.py`` module also includes essential classes for managing users and their sessions within the PrimAITE +simulation. These are the ``User``, ``UserManager``, and ``UserSessionManager`` classes. The base ``Node`` class comes +with ``UserManager``, and ``UserSessionManager`` classes pre-installed. + +User Class +---------- + +The ``User`` class represents a user in the system. It includes attributes such as ``username``, ``password``, +``disabled``, and ``is_admin`` to define the user's credentials and status. + +Example Usage +^^^^^^^^^^^^^ + +Creating a user: + .. code-block:: python + + user = User(username="john_doe", password="12345") + +UserManager Class +----------------- + +The ``UserManager`` class handles user management tasks such as creating users, authenticating them, changing passwords, +and enabling or disabling user accounts. It maintains a dictionary of users and provides methods to manage them +effectively. + +Example Usage +^^^^^^^^^^^^^ + +Creating a ``UserManager`` instance and adding a user: + .. code-block:: python + + user_manager = UserManager() + user_manager.add_user(username="john_doe", password="12345") + +Authenticating a user: + .. code-block:: python + + user = user_manager.authenticate_user(username="john_doe", password="12345") + +UserSessionManager Class +------------------------ + +The ``UserSessionManager`` class manages user sessions, including local and remote sessions. It handles session creation, +timeouts, and provides methods for logging users in and out. + +Example Usage +^^^^^^^^^^^^^ + +Creating a ``UserSessionManager`` instance and logging a user in locally: + .. code-block:: python + + session_manager = UserSessionManager() + session_id = session_manager.local_login(username="john_doe", password="12345") + +Logging a user out: + .. code-block:: python + + session_manager.local_logout() + +Practical Examples +------------------ + +Below are unit tests which act as practical examples illustrating how to use the ``User``, ``UserManager``, and +``UserSessionManager`` classes within the context of a client-server network simulation. + +Setting up a Client-Server Network +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + from typing import Tuple + from uuid import uuid4 + + import pytest + + from primaite.simulator.network.container import Network + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.network.hardware.nodes.host.server import Server + + @pytest.fixture(scope="function") + def client_server_network() -> Tuple[Computer, Server, Network]: + network = Network() + + client = Computer( + hostname="client", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + client.power_on() + + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + network.connect(client.network_interface[1], server.network_interface[1]) + + return client, server, network + +Local Login Success +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + def test_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + +Local Login Failure +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + def test_local_login_failure(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert not client.user_session_manager.local_user_logged_in + +Adding a New User and Successful Local Login +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + def test_new_user_local_login_success(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + client.user_manager.add_user(username="jane.doe", password="12345") + + client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + +Clearing Previous Login on New Local Login +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + def test_new_local_login_clears_previous_login(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + client.user_manager.add_user(username="jane.doe", password="12345") + + new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "jane.doe" + + assert new_session_id != current_session_id + +Persistent Login for the Same User +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + def test_new_local_login_attempt_same_uses_persists(client_server_network): + client, server, network = client_server_network + + assert not client.user_session_manager.local_user_logged_in + + current_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + new_session_id = client.user_session_manager.local_login(username="admin", password="admin") + + assert client.user_session_manager.local_user_logged_in + + assert client.user_session_manager.local_session.user.username == "admin" + + assert new_session_id == current_session_id diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 5ef8c14c..68abf9f2 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.airspace import AirSpaceFrequency -from primaite.simulator.network.hardware.base import NodeOperatingState +from primaite.simulator.network.hardware.base import NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Printer, Server @@ -267,6 +267,7 @@ class PrimaiteGame: for node_cfg in nodes_cfg: n_type = node_cfg["type"] + new_node = None if n_type == "computer": new_node = Computer( hostname=node_cfg["hostname"], @@ -316,6 +317,11 @@ class PrimaiteGame: msg = f"invalid node type {n_type} in config" _LOGGER.error(msg) raise ValueError(msg) + + if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): + user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa + for user_cfg in node_cfg["users"]: + user_manager.add_user(**user_cfg, bypass_can_perform_action=True) if "services" in node_cfg: for service_cfg in node_cfg["services"]: new_service = None diff --git a/tests/assets/configs/basic_node_with_users.yaml b/tests/assets/configs/basic_node_with_users.yaml new file mode 100644 index 00000000..064519dd --- /dev/null +++ b/tests/assets/configs/basic_node_with_users.yaml @@ -0,0 +1,34 @@ +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + protocols: + - ICMP + - UDP + + +simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + users: + - username: jane.doe + password: '1234' + is_admin: true + - username: john.doe + password: password_1 + is_admin: false diff --git a/tests/integration_tests/network/test_users_creation_from_config.py b/tests/integration_tests/network/test_users_creation_from_config.py new file mode 100644 index 00000000..8cd3b037 --- /dev/null +++ b/tests/integration_tests/network/test_users_creation_from_config.py @@ -0,0 +1,26 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import UserManager +from tests import TEST_ASSETS_ROOT + + +def test_users_from_config(): + config_path = TEST_ASSETS_ROOT / "configs" / "basic_node_with_users.yaml" + + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + network = PrimaiteGame.from_config(cfg=config_dict).simulation.network + + client_1 = network.get_node_by_hostname("client_1") + + user_manager: UserManager = client_1.software_manager.software["UserManager"] + + assert len(user_manager.users) == 3 + + assert user_manager.users["jane.doe"].password == "1234" + assert user_manager.users["jane.doe"].is_admin + + assert user_manager.users["john.doe"].password == "password_1" + assert not user_manager.users["john.doe"].is_admin From ab4931463f211891efca84e082f9aab1ebb428ef Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 2 Aug 2024 09:21:55 +0100 Subject: [PATCH 67/95] #2706 - Minor change following the session_id changes as local_login failed to pass a session_id when creating a new TerminalClientConnection object --- src/primaite/simulator/system/services/terminal/terminal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 192f0551..92893b14 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -218,7 +218,8 @@ class Terminal(Service): if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections - self._add_new_connection(connection_uuid=connection_uuid) + session_id = str(uuid4()) + self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id) return True else: self.sys_log.warning("Login failed, incorrect Username or Password") From 5dcc0189a0655a47cd5e51dc17f98a06e890c117 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 11:30:45 +0100 Subject: [PATCH 68/95] #2777: Implementation of RNG seed --- .../scripted_agents/probabilistic_agent.py | 14 ++++---- src/primaite/game/game.py | 2 ++ src/primaite/session/environment.py | 36 +++++++++++++++++++ src/primaite/session/ray_envs.py | 2 ++ 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index f5905ad0..ce1da3f2 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent): """Strict validation.""" action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" - random_seed: Optional[int] = None - """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. @@ -59,17 +57,19 @@ class ProbabilisticAgent(AbstractScriptedAgent): num_actions = len(action_space.action_map) settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} - # If seed not specified, set it to None so that numpy chooses a random one. - settings.setdefault("random_seed") - + # The random number seed for np.random is dependent on whether a random number seed is set + # in the config file. If there is one it is processed by set_random_seed() in environment.py + # and as a consequence the the sequence of rng_seed's used here will be repeatable. self.settings = ProbabilisticAgent.Settings(**settings) - - self.rng = np.random.default_rng(self.settings.random_seed) + rng_seed = np.random.randint(0, 65535) + self.rng = np.random.default_rng(rng_seed) + print(f"Probabilistic Agent - rng_seed: {rng_seed}") # convert probabilities from self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) + self.logger.info(f"ProbabilisticAgent RNG seed: {rng_seed}") def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 5ef8c14c..a4325b3e 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -70,6 +70,8 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") + seed: int = None + """Random number seed for RNGs.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[str] diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a87f0cde..359932c7 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,5 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK import json +import random +import sys from os import PathLike from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union @@ -17,6 +19,33 @@ from primaite.simulator.system.core.packet_capture import PacketCapture _LOGGER = getLogger(__name__) +# Check torch is installed +try: + import torch as th +except ModuleNotFoundError: + _LOGGER.debug("Torch not available for importing") + + +def set_random_seed(seed: int) -> Union[None, int]: + """ + Set random number generators. + + :param seed: int + """ + if seed is None or seed == -1: + return None + elif seed < -1: + raise ValueError("Invalid random number seed") + # Seed python RNG + random.seed(seed) + # Seed numpy RNG + np.random.seed(seed) + # Seed the RNG for all devices (both CPU and CUDA) + # if torch not installed don't set random seed. + if sys.modules["torch"]: + th.manual_seed(seed) + return seed + class PrimaiteGymEnv(gymnasium.Env): """ @@ -31,6 +60,9 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" + self.seed = self.episode_scheduler(0).get("game").get("seed") + """Get RNG seed from config file. NB: Must be before game instantiation.""" + self.seed = set_random_seed(self.seed) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -42,6 +74,8 @@ class PrimaiteGymEnv(gymnasium.Env): self.total_reward_per_episode: Dict[int, float] = {} """Average rewards of agents per episode.""" + _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -108,6 +142,8 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) + if seed is not None: + set_random_seed(seed) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/session/ray_envs.py b/src/primaite/session/ray_envs.py index 1adc324c..33c74b0e 100644 --- a/src/primaite/session/ray_envs.py +++ b/src/primaite/session/ray_envs.py @@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()} _LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}") @@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" + super().reset() # Ensure PRNG seed is set everywhere if self.env.agent.action_masking: obs, *_ = self.env.reset(seed=seed) new_obs = {"action_mask": self.env.action_masks(), "observations": obs} From 61c7cc2da37eb8c16ac612cd8c7466dcc2c8c197 Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Fri, 2 Aug 2024 10:57:51 +0000 Subject: [PATCH 69/95] Apply suggestions from code review --- src/primaite/simulator/network/hardware/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index d2aa4604..c2b0ecc4 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1505,12 +1505,12 @@ class Node(SimComponent): self.session_manager.software_manager = self.software_manager @property - def user_manager(self) -> UserManager: + def user_manager(self) -> Optional[UserManager]: """The Nodes User Manager.""" return self.software_manager.software.get("UserManager") # noqa @property - def user_session_manager(self) -> UserSessionManager: + def user_session_manager(self) -> Optional[UserSessionManager]: """The Nodes User Session Manager.""" return self.software_manager.software.get("UserSessionManager") # noqa From 696236aa6162283f4cfa96e3642f36f1a43901d2 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 2 Aug 2024 12:47:02 +0100 Subject: [PATCH 70/95] #2735 - make the disabled/enabled admins/non-admins dynamic properties for simplicity. Added num_of_logins to User. Added additional test for counting user logins. Added all users to the UserManager describe_state function. Refactored model fields with empty dict as default value to have direct instantiation instead of using Field(default_factory=dict) or Field(default_factory=: lambda: {}). --- .../simulator/network/hardware/base.py | 55 +++++++++++++++---- .../test_user_session_manager_logins.py | 24 ++++++++ 2 files changed, 68 insertions(+), 11 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index c2b0ecc4..1d320824 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -817,6 +817,9 @@ class User(SimComponent): is_admin: bool = False """Boolean flag indicating whether the user has admin privileges""" + num_of_logins: int = 0 + """Counts the number of the User has logged in""" + def describe_state(self) -> Dict: """ Returns a dictionary representing the current state of the user. @@ -835,9 +838,7 @@ class UserManager(Service): :param disabled_admins: A dictionary of currently disabled admin users by their usernames """ - users: Dict[str, User] = Field(default_factory=dict) - admins: Dict[str, User] = Field(default_factory=dict) - disabled_admins: Dict[str, User] = Field(default_factory=dict) + users: Dict[str, User] = {} def __init__(self, **kwargs): """ @@ -880,6 +881,7 @@ class UserManager(Service): """ state = super().describe_state() state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)}) + state["users"] = {k: v.describe_state() for k, v in self.users.items()} return state def show(self, markdown: bool = False): @@ -897,6 +899,42 @@ class UserManager(Service): table.add_row([user.username, user.is_admin, user.disabled]) print(table.get_string(sortby="Username")) + @property + def non_admins(self) -> Dict[str, User]: + """ + Returns a dictionary of all enabled non-admin users. + + :return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users. + """ + return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled} + + @property + def disabled_non_admins(self) -> Dict[str, User]: + """ + Returns a dictionary of all disabled non-admin users. + + :return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users. + """ + return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled} + + @property + def admins(self) -> Dict[str, User]: + """ + Returns a dictionary of all enabled admin users. + + :return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users. + """ + return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled} + + @property + def disabled_admins(self) -> Dict[str, User]: + """ + Returns a dictionary of all disabled admin users. + + :return: A dictionary with usernames as keys and User objects as values for admin, disabled users. + """ + return {k: v for k, v in self.users.items() if v.is_admin and v.disabled} + def install(self) -> None: """Setup default user during first-time installation.""" self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True) @@ -922,8 +960,6 @@ class UserManager(Service): return False user = User(username=username, password=password, is_admin=is_admin) self.users[username] = user - if is_admin: - self.admins[username] = user self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}") return True @@ -978,8 +1014,6 @@ class UserManager(Service): return False self.users[username].disabled = True self.sys_log.info(f"{self.name}: User disabled: {username}") - if username in self.admins: - self.disabled_admins[username] = self.admins.pop(username) return True self.sys_log.info(f"{self.name}: Failed to disable user: {username}") return False @@ -994,8 +1028,6 @@ class UserManager(Service): if username in self.users and self.users[username].disabled: self.users[username].disabled = False self.sys_log.info(f"{self.name}: User enabled: {username}") - if username in self.disabled_admins: - self.admins[username] = self.disabled_admins.pop(username) return True self.sys_log.info(f"{self.name}: Failed to enable user: {username}") return False @@ -1028,7 +1060,7 @@ class UserSession(SimComponent): """The timestep when the session ended, if applicable.""" local: bool = True - """Indicates if the session is local. Defaults to True.""" + """Indicates if the session is a local session or a remote session. Defaults to True as a local session.""" @classmethod def create(cls, user: User, timestep: int) -> UserSession: @@ -1041,6 +1073,7 @@ class UserSession(SimComponent): :param timestep: The timestep when the session is created. :return: An instance of UserSession. """ + user.num_of_logins += 1 return UserSession(user=user, start_step=timestep, last_active_step=timestep) def describe_state(self) -> Dict: @@ -1107,7 +1140,7 @@ class UserSessionManager(Service): local_session: Optional[UserSession] = None """The current local user session, if any.""" - remote_sessions: Dict[str, RemoteUserSession] = Field(default_factory=dict) + remote_sessions: Dict[str, RemoteUserSession] = {} """A dictionary of active remote user sessions.""" historic_sessions: List[UserSession] = Field(default_factory=list) diff --git a/tests/integration_tests/system/test_user_session_manager_logins.py b/tests/integration_tests/system/test_user_session_manager_logins.py index 955408ad..4318530c 100644 --- a/tests/integration_tests/system/test_user_session_manager_logins.py +++ b/tests/integration_tests/system/test_user_session_manager_logins.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import User from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -46,6 +47,29 @@ def test_local_login_success(client_server_network): assert client.user_session_manager.local_user_logged_in +def test_login_count_increases(client_server_network): + client, server, network = client_server_network + + admin_user: User = client.user_manager.users["admin"] + + assert admin_user.num_of_logins == 0 + + client.user_session_manager.local_login(username="admin", password="admin") + + assert admin_user.num_of_logins == 1 + + client.user_session_manager.local_login(username="admin", password="admin") + + # shouldn't change as user is already logged in + assert admin_user.num_of_logins == 1 + + client.user_session_manager.local_logout() + + client.user_session_manager.local_login(username="admin", password="admin") + + assert admin_user.num_of_logins == 2 + + def test_local_login_failure(client_server_network): client, server, network = client_server_network From a1e1a17c2a9fe87099b8bfcd9e3c3c0eab3bc408 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 12:49:17 +0100 Subject: [PATCH 71/95] #2777: Add RNG test --- .../game_layer/test_RNG_seed.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/integration_tests/game_layer/test_RNG_seed.py diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py new file mode 100644 index 00000000..c1bb7bb0 --- /dev/null +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -0,0 +1,43 @@ +from primaite.config.load import data_manipulation_config_path +from primaite.session.environment import PrimaiteGymEnv +from primaite.game.agent.interface import AgentHistoryItem +import yaml +from pprint import pprint +import pytest + +@pytest.fixture() +def create_env(): + with open(data_manipulation_config_path(), 'r') as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(env_config = cfg) + return env + +def test_rng_seed_set(create_env): + env = create_env + env.reset(seed=3) + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + env.reset(seed=3) + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + assert a==b + +def test_rng_seed_unset(create_env): + env = create_env + env.reset() + for i in range(100): + env.step(0) + a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + env.reset() + for i in range(100): + env.step(0) + b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + + assert a!=b + From 0cc724be605fff5e65a893d9ddebd5ed2517f342 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 2 Aug 2024 12:50:40 +0100 Subject: [PATCH 72/95] #2777: Updated CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cebc2569..7d7ba9c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Bandwidth Tracking**: Tracks data transmission across each frequency. - **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files. - **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file. - +- **Random Number Generator Seeding**: Added support for specifying a random number seed in the config file. ### Changed - **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`. From e132c52121a874d735118b03bc211431f9bcc8f0 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 2 Aug 2024 13:32:34 +0100 Subject: [PATCH 73/95] #2706 - Removed the LoginValidator. Will be handled by UserSessionManager. Updated some missing variables in method definitions/ --- .../system/services/terminal/terminal.py | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 92893b14..1b8497d0 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -8,8 +8,8 @@ from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel -from primaite.interface.request import RequestFormat, RequestResponse -from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType +from primaite.interface.request import RequestResponse +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.protocols.ssh import ( SSHConnectionMessage, @@ -50,7 +50,7 @@ class TerminalClientConnection(BaseModel): def disconnect(self): """Disconnect the connection.""" - if self.client and self.is_active: + if self.client: self.client._disconnect(self._connection_uuid) # noqa @@ -101,14 +101,10 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" - _login_valid = Terminal._LoginValidator(terminal=self) - rm = super()._init_request_manager() rm.add_request( "send", - request_type=RequestType( - func=lambda request, context: RequestResponse.from_bool(self.send()), validator=_login_valid - ), + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), ) def _login(request: List[Any], context: Any) -> RequestResponse: @@ -119,7 +115,7 @@ class Terminal(Service): return RequestResponse(status="failure", data={}) def _remote_login(request: List[Any], context: Any) -> RequestResponse: - login = self._process_remote_login(username=request[0], password=request[1], ip_address=request[2]) + login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) if login: return RequestResponse(status="success", data={}) else: @@ -152,32 +148,13 @@ class Terminal(Service): rm.add_request( "Execute", - request_type=RequestType(func=_execute_request, validator=_login_valid), + request_type=RequestType(func=_execute_request), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff, validator=_login_valid)) + rm.add_request("Logoff", request_type=RequestType(func=_logoff)) return rm - class _LoginValidator(RequestPermissionValidator): - """ - When requests come in, this validator will only allow them through if the User is logged into the Terminal. - - Login is required before making use of the Terminal. - """ - - terminal: Terminal - """Save a reference to the Terminal instance.""" - - def __call__(self, request: RequestFormat, context: Dict) -> bool: - """Return whether the Terminal is connected.""" - return self.terminal.is_connected - - @property - def fail_message(self) -> str: - """Message that is reported when a request is rejected by this validator.""" - return "Cannot perform request on terminal as not logged in." - def _add_new_connection(self, connection_uuid: str, session_id: str): """Create a new connection object and amend to list of active connections.""" self._connections[connection_uuid] = TerminalClientConnection( @@ -249,6 +226,7 @@ class Terminal(Service): """Processes a remote terminal requesting to login to this terminal. :param payload: The SSH Payload Packet. + :param session_id: Session ID for sending login response. :return: True if successful, else False. """ username: str = payload.user_account.username @@ -282,6 +260,7 @@ class Terminal(Service): """Receive Payload and process for a response. :param payload: The message contents received. + :param session_id: Session ID of received message. :return: True if successful, else False. """ self.sys_log.debug(f"Received payload: {payload}") @@ -335,7 +314,7 @@ class Terminal(Service): software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": self.connection_uuid}, + payload={"type": "disconnect", "connection_id": connection_uuid}, dest_ip_address=dest_ip_address, dest_port=self.port, ) From 4bddf72cd335fd52da74cc193dbc1471cf111684 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 09:29:17 +0100 Subject: [PATCH 74/95] #2706 - Initial refactor of Terminal Class following review discussion on Friday. Terminal will now return a TerminalConnection/RemoteTerminalConnection object on login. The new connection object can then be used to pass commands to the target node, without needing to form a full payload item. --- .../notebooks/Terminal-Processing.ipynb | 96 ++---- .../simulator/network/protocols/ssh.py | 2 + .../system/services/terminal/terminal.py | 293 +++++++++--------- .../_system/_services/test_terminal.py | 55 ++-- 4 files changed, 205 insertions(+), 241 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index fc795794..77be3822 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -26,7 +26,8 @@ "source": [ "from primaite.simulator.system.services.terminal.terminal import Terminal\n", "from primaite.simulator.network.container import Network\n", - "from primaite.simulator.network.hardware.nodes.host.computer import Computer" + "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", + "from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript" ] }, { @@ -83,7 +84,38 @@ "outputs": [], "source": [ "# Login to the remote (node_b) from local (node_a)\n", - "terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" + "from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n", + "\n", + "\n", + "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.software_manager.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(type(term_a_term_b_remote_connection))\n", + "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.software_manager.show()" ] }, { @@ -109,45 +141,6 @@ "The Terminal can be used to send requests to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage\n", - "from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n", - "\n", - "computer_b.software_manager.show()\n", - "\n", - "payload: SSHPacket = SSHPacket(\n", - " payload=[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n", - " transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n", - " connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n", - " sender_ip_address=computer_a.network_interface[1].ip_address,\n", - " target_ip_address=computer_b.network_interface[1].ip_address,\n", - ")\n", - "\n", - "# Send command to install RansomwareScript\n", - "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `RansomwareScript` can then be seen in the list of applications on the `node_b Software Manager`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "computer_b.software_manager.show()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -157,27 +150,6 @@ "Here, we send a command to `computer_b` to create a new folder titled \"Downloads\"." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "computer_b.file_system.show()\n", - "\n", - "payload: SSHPacket = SSHPacket(\n", - " payload=[\"file_system\", \"create\", \"folder\", \"Downloads\"],\n", - " transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,\n", - " connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,\n", - " sender_ip_address=computer_a.network_interface[1].ip_address,\n", - " target_ip_address=computer_b.network_interface[1].ip_address,\n", - ")\n", - "\n", - "terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)\n", - "\n", - "computer_b.file_system.show()" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 4ec043b8..7ba629f8 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -76,6 +76,8 @@ class SSHPacket(DataPacket): user_account: Optional[SSHUserCredentials] = None """User Account Credentials if passed""" + connection_request_uuid: Optional[str] = None # Connection Request uuid. + connection_uuid: Optional[str] = None # The connection uuid used to validate the session ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 1b8497d0..b7bc5287 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -10,18 +10,11 @@ from pydantic import BaseModel from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.hardware.base import Node -from primaite.simulator.network.protocols.ssh import ( - SSHConnectionMessage, - SSHPacket, - SSHTransportMessage, - SSHUserCredentials, -) +from primaite.simulator.network.protocols.ssh import SSHPacket from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState -from primaite.simulator.system.software import SoftwareHealthState class TerminalClientConnection(BaseModel): @@ -31,40 +24,45 @@ class TerminalClientConnection(BaseModel): This class is used to record current User Connections to the Terminal class. """ - parent_node: Node # Technically should be HostNode but this causes circular import error. + parent_terminal: Terminal """The parent Node that this connection was created on.""" - dest_ip_address: IPv4Address = None - """Destination IP address of connection""" - session_id: str = None """Session ID that connection is linked to""" - _connection_uuid: str = None + connection_uuid: str = None """Connection UUID""" @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" - return self.parent_node.software_manager.software.get("Terminal") + return self.parent_terminal - def disconnect(self): - """Disconnect the connection.""" - if self.client: - self.client._disconnect(self._connection_uuid) # noqa + def disconnect(self) -> bool: + """Disconnect the session.""" + return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid) + + +class RemoteTerminalConnection(TerminalClientConnection): + """ + RemoteTerminalConnection Class. + + This class acts as broker between the terminal and remote. + + """ + + def execute(self, command: Any) -> bool: + """Execute a given command on the remote Terminal.""" + if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: + self.parent_terminal.sys_log.warning("Cannot process command as system not running") + # Send command to remote terminal to process. + return self.parent_terminal.send(payload=command, session_id=self.session_id) class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - operating_state: ServiceOperatingState = ServiceOperatingState.RUNNING - "Initial Operating State" - - health_state_actual: SoftwareHealthState = SoftwareHealthState.GOOD - "Service Health State" - - _connections: Dict[str, TerminalClientConnection] = {} - "List of active connections held on this terminal." + _client_connection_requests: Dict[str, Optional[str]] = {} def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -155,34 +153,40 @@ class Terminal(Service): return rm - def _add_new_connection(self, connection_uuid: str, session_id: str): + def execute(self, command: List[Any]) -> RequestResponse: + """Execute a passed ssh command via the request manager.""" + return self.parent.apply_request(command) + + def _create_local_connection(self, connection_uuid: str, session_id: str) -> RemoteTerminalConnection: """Create a new connection object and amend to list of active connections.""" - self._connections[connection_uuid] = TerminalClientConnection( - parent_node=self.software_manager.node, + new_connection = TerminalClientConnection( + parent_terminal=self, connection_uuid=connection_uuid, session_id=session_id, ) + self._connections[connection_uuid] = new_connection + self._client_connection_requests[connection_uuid] = new_connection - def login(self, username: str, password: str, ip_address: Optional[IPv4Address] = None) -> bool: - """Process User request to login to Terminal. + return new_connection - If ip_address is passed, login will attempt a remote login to the node at that address. - :param username: The username credential. - :param password: The user password component of credentials. - :param dest_ip_address: The IP address of the node we want to connect to. - :return: True if successful, False otherwise. - """ + def login( + self, username: str, password: str, ip_address: Optional[IPv4Address] = None + ) -> Optional[TerminalClientConnection]: + """Login to the terminal. Will attempt a remote login if ip_address is given, else local.""" if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning("Cannot process login as service is not running") - return False - + self.sys_log.warning("Cannot login as service is not running.") + return None + connection_request_id = str(uuid4()) + self._client_connection_requests[connection_request_id] = None if ip_address: - # if ip_address has been provided, we assume we are logging in to a remote terminal. - return self._send_remote_login(username=username, password=password, ip_address=ip_address) + # Assuming that if IP is passed we are connecting to remote + return self._send_remote_login( + username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id + ) + else: + return self._process_local_login(username=username, password=password) - return self._process_local_login(username=username, password=password) - - def _process_local_login(self, username: str, password: str) -> bool: + def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]: """Local session login to terminal. :param username: Username for login. @@ -195,110 +199,114 @@ class Terminal(Service): if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections - session_id = str(uuid4()) - self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id) - return True + self._create_local_connection(connection_uuid=connection_uuid, session_id="") + return TerminalClientConnection(parent_terminal=self, session_id="", connection_uuid=connection_uuid) else: self.sys_log.warning("Login failed, incorrect Username or Password") - return False + return None - def _send_remote_login(self, username: str, password: str, ip_address: IPv4Address) -> bool: - """Attempt to login to a remote terminal. + def _check_client_connection(self, connection_id: str) -> bool: + """Check that client_connection_id is valid.""" + return True if connection_id in self._client_connection_requests else False - :param username: username for login. - :param password: password for login. - :ip_address: IP address of the target node for login. - """ - transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST - connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA - user_account: SSHUserCredentials = SSHUserCredentials(username=username, password=password) + def _send_remote_login( + self, + username: str, + password: str, + ip_address: IPv4Address, + connection_request_id: str, + is_reattempt: bool = False, + ) -> Optional[RemoteTerminalConnection]: + """Process a remote login attempt.""" + self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") + if is_reattempt: + valid_connection = self._check_client_connection(connection_id=connection_request_id) + if valid_connection: + remote_terminal_connection = self._client_connection_requests.pop(connection_request_id) + self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") + return remote_terminal_connection + else: + self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") + return None - payload: SSHPacket = SSHPacket( - transport_message=transport_message, - connection_message=connection_message, - user_account=user_account, + payload = { + "type": "login_request", + "username": username, + "password": password, + "connection_request_id": connection_request_id, + } + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_ip_address=ip_address, dest_port=self.port + ) + return self._send_remote_login( + username=username, + password=password, + ip_address=ip_address, + is_reattempt=True, + connection_request_id=connection_request_id, ) - self.sys_log.info(f"Sending remote login request to {ip_address}") - return self.send(payload=payload, dest_ip_address=ip_address) + def _create_remote_connection(self, connection_id: str, connection_request_id: str, session_id: str) -> None: + """Create a new TerminalClientConnection Object.""" + client_connection = RemoteTerminalConnection( + parent_terminal=self, session_id=session_id, connection_uuid=connection_id + ) + self._connections[connection_id] = client_connection + self._client_connection_requests[connection_request_id] = client_connection - def _process_remote_login(self, payload: SSHPacket, session_id: str) -> bool: - """Processes a remote terminal requesting to login to this terminal. - - :param payload: The SSH Payload Packet. - :param session_id: Session ID for sending login response. - :return: True if successful, else False. + def receive(self, session_id: str, payload: Any, **kwargs) -> bool: """ - username: str = payload.user_account.username - password: str = payload.user_account.password - self.sys_log.info(f"Sending UserAuth request to UserSessionManager, username={username}, password={password}") - # TODO: Un-comment this when UserSessionManager is merged. - # connection_uuid = self.parent.UserSessionManager.remote_login(username=username, password=password) - connection_uuid = str(uuid4()) - if connection_uuid: - # Send uuid to remote - self.sys_log.info( - f"Remote login authorised, connection ID {connection_uuid} for " f"{username} in session {session_id}" - ) - transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS - connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA - return_payload = SSHPacket( - transport_message=transport_message, - connection_message=connection_message, - connection_uuid=connection_uuid, - ) - self._add_new_connection(connection_uuid=connection_uuid, session_id=session_id) + Receive a payload from the Software Manager. - self.send(payload=return_payload, session_id=session_id) - return True - else: - # UserSessionManager has returned None - self.sys_log.warning("Login failed, incorrect Username or Password") - return False - - def receive(self, payload: SSHPacket, session_id: str, **kwargs) -> bool: - """Receive Payload and process for a response. - - :param payload: The message contents received. - :param session_id: Session ID of received message. - :return: True if successful, else False. + :param payload: A payload to receive. + :param session_id: The session id the payload relates to. + :return: True. """ - self.sys_log.debug(f"Received payload: {payload}") + self.sys_log.info(f"Received payload: {payload}") + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "login_request": + # add connection + connection_request_id = payload["connection_request_id"] + username = payload["username"] + password = payload["password"] + print(f"Connection ID is: {connection_request_id}") + self.sys_log.info(f"Connection authorised, session_id: {session_id}") + self._create_remote_connection( + connection_id=connection_request_id, + connection_request_id=payload["connection_request_id"], + session_id=session_id, + ) + payload = { + "type": "login_success", + "username": username, + "password": password, + "connection_request_id": connection_request_id, + } + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) + elif payload["type"] == "login_success": + self.sys_log.info(f"Login was successful! session_id is:{session_id}") + connection_request_id = payload["connection_request_id"] + self._create_remote_connection( + connection_id=connection_request_id, + session_id=session_id, + connection_request_id=connection_request_id, + ) - if not isinstance(payload, SSHPacket): - return False + elif payload["type"] == "disconnect": + connection_id = payload["connection_id"] + self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server") + self._disconnect(payload["connection_id"]) - if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.warning("Cannot process message as not running") - return False - - if payload.connection_message == SSHConnectionMessage.SSH_MSG_CHANNEL_CLOSE: - # Close the channel - connection_id = kwargs["connection_id"] - dest_ip_address = kwargs["dest_ip_address"] - self.disconnect(dest_ip_address=dest_ip_address) - self.sys_log.debug(f"Disconnecting {connection_id}") - - elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: - return self._process_remote_login(payload=payload, session_id=session_id) - - elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: - self.sys_log.info(f"Login Successful, connection ID is {payload.connection_uuid}") - return True - - elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: - return self.execute(command=payload.payload) - - else: - self.sys_log.warning("Encounter unexpected message type, rejecting connection") - return False + if isinstance(payload, list): + # A request? For me? + self.execute(payload) return True - def execute(self, command: List[Any]) -> RequestResponse: - """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) - def _disconnect(self, connection_uuid: str) -> bool: """Disconnect from the remote. @@ -309,30 +317,16 @@ class Terminal(Service): self.sys_log.warning("No remote connection present") return False - dest_ip_address = self._connections[connection_uuid].dest_ip_address + session_id = self._connections[connection_uuid].session_id self._connections.pop(connection_uuid) software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_uuid}, - dest_ip_address=dest_ip_address, - dest_port=self.port, + payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id ) self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") return True - def disconnect(self, connection_uuid: Optional[str]) -> bool: - """Disconnect the terminal. - - If no connection id has been supplied, disconnects the first connection. - :param connection_uuid: Connection ID that we want to disconnect. - :return: True if successful, False otherwise. - """ - if not connection_uuid: - connection_uuid = next(iter(self._connections)) - - return self._disconnect(connection_uuid=connection_uuid) - def send( self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None ) -> bool: @@ -345,6 +339,7 @@ class Terminal(Service): if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!") return False + self.sys_log.debug(f"Sending payload: {payload}") return super().send( payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index d4592228..2f093dae 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -16,7 +16,7 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState -from primaite.simulator.system.services.terminal.terminal import Terminal +from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal from primaite.simulator.system.services.web_server.web_server import WebServer @@ -87,8 +87,6 @@ def test_terminal_send(basic_network): payload="Test_Payload", transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - sender_ip_address=computer_a.network_interface[1].ip_address, - target_ip_address=computer_b.network_interface[1].ip_address, ) assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) @@ -106,11 +104,13 @@ def test_terminal_receive(basic_network): payload=["file_system", "create", "folder", folder_name], transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - sender_ip_address=computer_a.network_interface[1].ip_address, - target_ip_address=computer_b.network_interface[1].ip_address, ) - assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) + term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( + username="username", password="password", ip_address="192.168.0.11" + ) + + term_a_on_node_b.execute(["file_system", "create", "folder", folder_name]) # Assert that the Folder has been correctly created assert computer_b.file_system.get_folder(folder_name) @@ -127,11 +127,13 @@ def test_terminal_install(basic_network): payload=["software_manager", "application", "install", "RansomwareScript"], transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, - sender_ip_address=computer_a.network_interface[1].ip_address, - target_ip_address=computer_b.network_interface[1].ip_address, ) - terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) + term_a_on_node_b: RemoteTerminalConnection = terminal_a.login( + username="username", password="password", ip_address="192.168.0.11" + ) + + term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"]) assert computer_b.software_manager.software.get("RansomwareScript") @@ -145,29 +147,30 @@ def test_terminal_fail_when_closed(basic_network): terminal.operating_state = ServiceOperatingState.STOPPED - assert ( - terminal.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) - is False + assert not terminal.login( + username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address ) def test_terminal_disconnect(basic_network): - """Terminal should set is_connected to false on disconnect""" + """Test Terminal disconnects""" network: Network = basic_network computer_a: Computer = network.get_node_by_hostname("node_a") terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") computer_b: Computer = network.get_node_by_hostname("node_b") terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") - assert terminal_a.is_connected is False + assert len(terminal_b._connections) == 0 - terminal_a.login(username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address) + term_a_on_term_b = terminal_a.login( + username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address + ) - assert terminal_a.is_connected is True + assert len(terminal_b._connections) == 1 - terminal_a.disconnect(dest_ip_address=computer_b.network_interface[1].ip_address) + term_a_on_term_b.disconnect() - assert terminal_a.is_connected is False + assert len(terminal_b._connections) == 0 def test_terminal_ignores_when_off(basic_network): @@ -178,21 +181,13 @@ def test_terminal_ignores_when_off(basic_network): computer_b: Computer = network.get_node_by_hostname("node_b") - terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") # login to computer_b - - assert terminal_a.is_connected is True + term_a_on_term_b: RemoteTerminalConnection = terminal_a.login( + username="admin", password="Admin123!", ip_address="192.168.0.11" + ) # login to computer_b terminal_a.operating_state = ServiceOperatingState.STOPPED - payload: SSHPacket = SSHPacket( - payload="Test_Payload", - transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, - sender_ip_address=computer_a.network_interface[1].ip_address, - target_ip_address="192.168.0.11", - ) - - assert not terminal_a.send(payload=payload, dest_ip_address="192.168.0.11") + assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"]) def test_network_simulation(basic_network): From 814663cf2c2ab4136efe2572aa08d7b756d899ea Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 10:04:23 +0100 Subject: [PATCH 75/95] #2706 - Terminal now installs on a Router --- src/primaite/simulator/network/hardware/base.py | 6 ++++++ .../simulator/network/hardware/nodes/host/host_node.py | 2 +- .../simulator/network/hardware/nodes/network/router.py | 2 ++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 4994e7d3..9230dd47 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -30,6 +30,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.processes.process import Process from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.simulator.system.software import IOSoftware, Software from primaite.utils.converters import convert_dict_enum_keys_to_enum_values from primaite.utils.validators import IPV4Address @@ -1541,6 +1542,11 @@ class Node(SimComponent): """The Nodes User Session Manager.""" return self.software_manager.software.get("UserSessionManager") # noqa + @property + def terminal(self) -> Optional[Terminal]: + """The Nodes Terminal.""" + return self.software_manager.software.get("Terminal") + def local_login(self, username: str, password: str) -> Optional[str]: """ Attempt to log in to the node uas a local user. diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 7393490b..c197d30b 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -314,9 +314,9 @@ class HostNode(Node): "NTPClient": NTPClient, "WebBrowser": WebBrowser, "NMAP": NMAP, - "Terminal": Terminal, "UserSessionManager": UserSessionManager, "UserManager": UserManager, + "Terminal": Terminal, } """List of system software that is automatically installed on nodes.""" diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 42821120..ceb91695 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -24,6 +24,7 @@ from primaite.simulator.system.core.session_manager import SessionManager from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.services.arp.arp import ARP from primaite.simulator.system.services.icmp.icmp import ICMP +from primaite.simulator.system.services.terminal.terminal import Terminal from primaite.utils.validators import IPV4Address @@ -1203,6 +1204,7 @@ class Router(NetworkNode): SYSTEM_SOFTWARE: ClassVar[Dict] = { "UserSessionManager": UserSessionManager, "UserManager": UserManager, + "Terminal": Terminal, } num_ports: int From 2e4a1c37d1708ba7d01e3c16005f81938e1f9796 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 10:34:06 +0100 Subject: [PATCH 76/95] #2777: Pre-commit fixes to test --- .../game_layer/test_RNG_seed.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index c1bb7bb0..0c6d567d 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -1,43 +1,50 @@ -from primaite.config.load import data_manipulation_config_path -from primaite.session.environment import PrimaiteGymEnv -from primaite.game.agent.interface import AgentHistoryItem -import yaml +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from pprint import pprint + import pytest +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.agent.interface import AgentHistoryItem +from primaite.session.environment import PrimaiteGymEnv + @pytest.fixture() def create_env(): - with open(data_manipulation_config_path(), 'r') as f: + with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(env_config = cfg) + env = PrimaiteGymEnv(env_config=cfg) return env + def test_rng_seed_set(create_env): + """Test with RNG seed set.""" env = create_env env.reset(seed=3) for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] env.reset(seed=3) for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + + assert a == b - assert a==b def test_rng_seed_unset(create_env): + """Test with no RNG seed.""" env = create_env env.reset() for i in range(100): env.step(0) - a = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] + a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] env.reset() for i in range(100): env.step(0) - b = [item.timestep for item in env.game.agents['client_2_green_user'].history if item.action!="DONOTHING"] - - assert a!=b + b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] + assert a != b From ca8e56873440eaffe09db6c037c434d478c91029 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 10:58:23 +0100 Subject: [PATCH 77/95] #2706 - Additional tests to check terminal login to/from networknodes. Redo of test to check that a router will block SSH traffic if no ACL rule. --- .../_system/_services/test_terminal.py | 189 +++++++++++------- 1 file changed, 117 insertions(+), 72 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 2f093dae..5010cd8f 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -10,6 +10,7 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -45,6 +46,72 @@ def basic_network() -> Network: return network +@pytest.fixture(scope="function") +def wireless_wan_network(): + network = Network() + + # Configure PC A + pc_a = Computer( + hostname="pc_a", + ip_address="192.168.0.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.0.1", + start_up_duration=0, + ) + pc_a.power_on() + network.add_node(pc_a) + + # Configure Router 1 + router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace) + router_1.power_on() + network.add_node(router_1) + + # Configure the connection between PC A and Router 1 port 2 + router_1.configure_router_interface("192.168.0.1", "255.255.255.0") + network.connect(pc_a.network_interface[1], router_1.network_interface[2]) + + # Configure Router 1 ACLs + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + + # Configure PC B + pc_b = Computer( + hostname="pc_b", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + start_up_duration=0, + ) + pc_b.power_on() + network.add_node(pc_b) + + # Configure Router 2 + router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace) + router_2.power_on() + network.add_node(router_2) + + # Configure the connection between PC B and Router 2 port 2 + router_2.configure_router_interface("192.168.2.1", "255.255.255.0") + network.connect(pc_b.network_interface[1], router_2.network_interface[2]) + + # Configure Router 2 ACLs + + # Configure the wireless connection between Router 1 port 1 and Router 2 port 1 + router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0") + router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0") + + router_1.route_table.add_route( + address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2" + ) + + # Configure Route from Router 2 to PC A subnet + router_2.route_table.add_route( + address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1" + ) + + return pc_a, pc_b, router_1, router_2 + + @pytest.fixture def game_and_agent_fixture(game_and_agent): """Create a game with a simple agent that can be controlled by the tests.""" @@ -190,86 +257,64 @@ def test_terminal_ignores_when_off(basic_network): assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"]) -def test_network_simulation(basic_network): - # 0: Pull out the network - network = basic_network +def test_computer_remote_login_to_router(wireless_wan_network): + """Test to confirm that a computer can SSH into a router.""" + pc_a, pc_b, router_1, router_2 = wireless_wan_network - # 1: Set up network hardware - # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, start_up_duration=0) - router.power_on() - router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") - router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) - # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) - switch_1.power_on() - network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) - router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) - switch_2.power_on() - network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) - router.enable_port(2) + pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") + pc_b_terminal: Terminal = pc_b.software_manager.software.get("Terminal") - # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - start_up_duration=0, - ) - client_1.power_on() - network.connect( - endpoint_a=client_1.network_interface[1], - endpoint_b=switch_1.network_interface[1], - ) + router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") + router_2_terminal: Terminal = router_2.software_manager.software.get("Terminal") - client_2 = Computer( - hostname="client_2", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - ) - client_2.power_on() - network.connect(endpoint_a=client_2.network_interface[1], endpoint_b=switch_2.network_interface[1]) + assert len(pc_a_terminal._connections) == 0 - # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) - server_1.power_on() - network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1") - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) - server_2.power_on() - network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) + assert len(pc_a_terminal._connections) == 1 - # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + payload = ["software_manager", "application", "install", "RansomwareScript"] - # 3: Install server software - server_1.software_manager.install(DNSServer) - dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa - dns_service.dns_register("www.example.com", server_2.network_interface[1].ip_address) - server_2.software_manager.install(WebServer) + pc_a_on_router_1.execute(payload) - # 3.1: Ensure that the dns clients are configured correctly - client_1.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address - server_2.software_manager.software.get("DNSClient").dns_server = server_1.network_interface[1].ip_address + assert router_1.software_manager.software.get("RansomwareScript") - terminal_1: Terminal = client_1.software_manager.software.get("Terminal") - assert terminal_1.login(username="admin", password="Admin123!", ip_address="10.0.2.2") is False +def test_router_remote_login_to_computer(wireless_wan_network): + """Test to confirm that a router can ssh into a computer.""" + pc_a, pc_b, router_1, router_2 = wireless_wan_network + + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + + pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") + pc_b_terminal: Terminal = pc_b.software_manager.software.get("Terminal") + + router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") + router_2_terminal: Terminal = router_2.software_manager.software.get("Terminal") + + assert len(router_1_terminal._connections) == 0 + + router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2") + + assert len(router_1_terminal._connections) == 1 + + payload = ["software_manager", "application", "install", "RansomwareScript"] + + router_1_on_pc_a.execute(payload) + + assert pc_a.software_manager.software.get("RansomwareScript") + + +def test_router_blocks_SSH_traffic(wireless_wan_network): + """Test to check that router will block SSH traffic if no ACL rule.""" + pc_a, _, _, router_2 = wireless_wan_network + + pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") + + assert len(pc_a_terminal._connections) == 0 + + pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2") + + assert len(pc_a_terminal._connections) == 0 From 7d7117e6246d96a46bae4a1a0c6c619c219a44b5 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 11:13:32 +0100 Subject: [PATCH 78/95] #2777: Merge with dev --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68745913..c52f4678 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML - Agent logging for agents' internal decision logic - Action masking in all PrimAITE environments -- **Random Number Generator Seeding**: Added support for specifying a random number seed in the config file. +- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Application registry was moved to the `Application` class and now updates automatically when Application is subclassed - Databases can no longer respond to request while performing a backup From 972b0b9712e7f3f95da5a7e28b5e7c05289b4817 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 11:19:27 +0100 Subject: [PATCH 79/95] #2706 - Added another test demonstrating an SSH connection across a network. Actioned some review comments and a minor change to other ACL Terminal tests --- .../_system/_services/test_terminal.py | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 5010cd8f..794e88bf 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -29,7 +29,7 @@ def terminal_on_computer() -> Tuple[Terminal, Computer]: computer.power_on() terminal: Terminal = computer.software_manager.software.get("Terminal") - return [terminal, computer] + return terminal, computer @pytest.fixture(scope="function") @@ -74,6 +74,9 @@ def wireless_wan_network(): router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + # add ACL rule to allow SSH traffic + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + # Configure PC B pc_b = Computer( hostname="pc_b", @@ -120,7 +123,7 @@ def game_and_agent_fixture(game_and_agent): client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") client_1.start_up_duration = 3 - return (game, agent) + return game, agent def test_terminal_creation(terminal_on_computer): @@ -259,15 +262,9 @@ def test_terminal_ignores_when_off(basic_network): def test_computer_remote_login_to_router(wireless_wan_network): """Test to confirm that a computer can SSH into a router.""" - pc_a, pc_b, router_1, router_2 = wireless_wan_network - - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + pc_a, _, router_1, _ = wireless_wan_network pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") - pc_b_terminal: Terminal = pc_b.software_manager.software.get("Terminal") - - router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") - router_2_terminal: Terminal = router_2.software_manager.software.get("Terminal") assert len(pc_a_terminal._connections) == 0 @@ -284,15 +281,11 @@ def test_computer_remote_login_to_router(wireless_wan_network): def test_router_remote_login_to_computer(wireless_wan_network): """Test to confirm that a router can ssh into a computer.""" - pc_a, pc_b, router_1, router_2 = wireless_wan_network + pc_a, _, router_1, _ = wireless_wan_network - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) - - pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") - pc_b_terminal: Terminal = pc_b.software_manager.software.get("Terminal") + router_1: Router = router_1 router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") - router_2_terminal: Terminal = router_2.software_manager.software.get("Terminal") assert len(router_1_terminal._connections) == 0 @@ -309,7 +302,12 @@ def test_router_remote_login_to_computer(wireless_wan_network): def test_router_blocks_SSH_traffic(wireless_wan_network): """Test to check that router will block SSH traffic if no ACL rule.""" - pc_a, _, _, router_2 = wireless_wan_network + pc_a, _, router_1, _ = wireless_wan_network + + router_1: Router = router_1 + + # Remove rule that allows SSH traffic. + router_1.acl.remove_rule(position=21) pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal") @@ -318,3 +316,19 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2") assert len(pc_a_terminal._connections) == 0 + + +def test_SSH_across_network(wireless_wan_network): + """Test to show ability to SSH across a network.""" + pc_a, pc_b, router_1, router_2 = wireless_wan_network + + terminal_a: Terminal = pc_a.software_manager.software.get("Terminal") + terminal_b: Terminal = pc_b.software_manager.software.get("Terminal") + + router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21) + + assert len(terminal_a._connections) == 0 + + terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2") + + assert len(terminal_a._connections) == 1 From 966542c2ca1b00d128594ae4afdd638d45160972 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 15:08:31 +0100 Subject: [PATCH 80/95] #2777: Add determinism to torch backends when seed set. --- src/primaite/session/environment.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 359932c7..a12d2eb7 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -44,6 +44,10 @@ def set_random_seed(seed: int) -> Union[None, int]: # if torch not installed don't set random seed. if sys.modules["torch"]: th.manual_seed(seed) + + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + return seed From d059ddceaba77ac60ed9f24b4120e3375bfc384c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 15:11:57 +0100 Subject: [PATCH 81/95] #2777: Remove debug print statement --- src/primaite/game/agent/scripted_agents/probabilistic_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index ce1da3f2..ab2e69ef 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -63,7 +63,6 @@ class ProbabilisticAgent(AbstractScriptedAgent): self.settings = ProbabilisticAgent.Settings(**settings) rng_seed = np.random.randint(0, 65535) self.rng = np.random.default_rng(rng_seed) - print(f"Probabilistic Agent - rng_seed: {rng_seed}") # convert probabilities from self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) From 4fe9753fcf5f80b776ebcb893492eca56e566556 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 15:44:52 +0100 Subject: [PATCH 82/95] #2706 - Updated terminal.receive() to work with SSHPacket class, fixed some tests and updated RemoteTerminalConnection to hold Source_IP for easier reading --- .../system/services/terminal.rst | 16 +- .../notebooks/Terminal-Processing.ipynb | 103 ++++++---- .../system/services/terminal/terminal.py | 186 +++++++++++++----- tests/integration_tests/system/test_nmap.py | 2 +- .../_system/_services/test_terminal.py | 18 +- 5 files changed, 222 insertions(+), 103 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 4b02a6db..37872b5b 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -19,7 +19,6 @@ installed on Nodes when they are instantiated. Key capabilities ================ - - Authenticates User connection by maintaining an active User account. - Ensures packets are matched to an existing session - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. @@ -27,21 +26,18 @@ Key capabilities Usage ===== - - Pre-Installs on any `HostNode` component. See the below code example of how to access the terminal. - - Terminal Clients connect, execute commands and disconnect from remote components. + - Pre-Installs on any `Node` (component with the exception of `Switches`). + - Terminal Clients connect, execute commands and disconnect from remote nodes. - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. Implementation ============== -The terminal takes inspiration from the `Database Client` and `Database Service` classes, and leverages the `UserSessionManager` -to provide User Credential authentication when receiving/processing commands. - -Terminal acts as the interface between the user/component and both the Session and Requests Managers, facilitating -the passing of requests to both. - -A more detailed example of how to use the Terminal class can be found in the Terminal-Processing jupyter notebook. + - Manages remote connections in a dictionary by session ID. + - Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate. + - Extends Service class. + - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. Python """""" diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 77be3822..30b1a5e7 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -15,7 +15,7 @@ "source": [ "This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n", "\n", - "By default, the Terminal will come pre-installed on any simulation component which inherits from `HostNode` (Computer, Server, Printer), and simulates the Secure Shell (SSH) protocol as the communication method." + "The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). " ] }, { @@ -27,15 +27,9 @@ "from primaite.simulator.system.services.terminal.terminal import Terminal\n", "from primaite.simulator.network.container import Network\n", "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", - "from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n", + "from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n", + "\n", "def basic_network() -> Network:\n", " \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n", " network = Network()\n", @@ -51,9 +45,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The terminal can be accessed from a `HostNode` via the `software_manager` as demonstrated below. \n", + "The terminal can be accessed from a `Node` via the `software_manager` as demonstrated below. \n", "\n", - "In the example, we have a basic network consisting of two computers " + "In the example, we have a basic network consisting of two computers, connected to form a basic network." ] }, { @@ -66,15 +60,17 @@ "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", - "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")\n" + "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are logging in to the 'admin' account on `node_b`. \n", - "If you are not logged in, any commands sent will be rejected." + "To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are remotely logging in to the 'admin' account on `node_b`, from `node_a`. \n", + "If you are not logged in, any commands sent will be rejected by the remote.\n", + "\n", + "Remote Logins return a RemoteTerminalConnection object, which can be used for sending commands to the remote node. " ] }, { @@ -84,10 +80,14 @@ "outputs": [], "source": [ "# Login to the remote (node_b) from local (node_a)\n", - "from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n", - "\n", - "\n", - "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=computer_b.network_interface[1].ip_address)" + "term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can view all active connections to a terminal through use of the `show()` method" ] }, { @@ -96,7 +96,14 @@ "metadata": {}, "outputs": [], "source": [ - "computer_b.software_manager.show()" + "terminal_b.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The new connection object allows us to forward commands to be executed on the target node. The example below demonstrates how you can remotely install an application on the target node." ] }, { @@ -105,7 +112,6 @@ "metadata": {}, "outputs": [], "source": [ - "print(type(term_a_term_b_remote_connection))\n", "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])" ] }, @@ -122,7 +128,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can view all remote connections to a terminal through use of the `show()` method" + "The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to create a downloads folder. \n" ] }, { @@ -131,23 +137,11 @@ "metadata": {}, "outputs": [], "source": [ - "terminal_b.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The Terminal can be used to send requests to install new software. The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to install the `RansomwareScript` application. \n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The below example shows how you can send a command via the terminal to create a folder on the target Node.\n", + "# Display the current state of the file system on computer_b\n", + "computer_b.file_system.show()\n", "\n", - "Here, we send a command to `computer_b` to create a new folder titled \"Downloads\"." + "# Send command\n", + "term_a_term_b_remote_connection.execute([\"file_system\", \"create\", \"folder\", \"downloads\"])" ] }, { @@ -156,6 +150,39 @@ "source": [ "The resultant call to `computer_b.file_system.show()` shows that the new folder has been created." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "computer_b.file_system.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When finished, the connection can be closed by calling the `disconnect` function of the Remote Client object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display active connection\n", + "terminal_a.show()\n", + "terminal_b.show()\n", + "\n", + "term_a_term_b_remote_connection.disconnect()\n", + "\n", + "terminal_a.show()\n", + "\n", + "terminal_b.show()" + ] } ], "metadata": { diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index b7bc5287..0f8e180e 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable @@ -10,7 +10,12 @@ from pydantic import BaseModel from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType -from primaite.simulator.network.protocols.ssh import SSHPacket +from primaite.simulator.network.protocols.ssh import ( + SSHConnectionMessage, + SSHPacket, + SSHTransportMessage, + SSHUserCredentials, +) from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager @@ -33,6 +38,9 @@ class TerminalClientConnection(BaseModel): connection_uuid: str = None """Connection UUID""" + connection_request_id: str = None + """Connection request ID""" + @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" @@ -51,6 +59,9 @@ class RemoteTerminalConnection(TerminalClientConnection): """ + source_ip: IPv4Address + """Source IP of Connection""" + def execute(self, command: Any) -> bool: """Execute a given command on the remote Terminal.""" if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: @@ -88,13 +99,13 @@ class Terminal(Service): :param markdown: Whether to display the table in Markdown format or not. Default is `False`. """ - table = PrettyTable(["Connection ID", "Session_ID"]) + table = PrettyTable(["Connection ID", "Connection request ID", "Source IP"]) if markdown: table.set_style(MARKDOWN) table.align = "l" table.title = f"{self.sys_log.hostname} {self.name} Connections" for connection_id, connection in self._connections.items(): - table.add_row([connection_id, connection.session_id]) + table.add_row([connection_id, connection.connection_request_id, connection.source_ip]) print(table.get_string(sortby="Connection ID")) def _init_request_manager(self) -> RequestManager: @@ -130,7 +141,7 @@ class Terminal(Service): connection_uuid = request[0] # TODO: Uncomment this when UserSessionManager merged. # self.parent.UserSessionManager.logoff(connection_uuid) - self.disconnect(connection_uuid) + self._disconnect(connection_uuid) return RequestResponse(status="success", data={}) @@ -157,8 +168,13 @@ class Terminal(Service): """Execute a passed ssh command via the request manager.""" return self.parent.apply_request(command) - def _create_local_connection(self, connection_uuid: str, session_id: str) -> RemoteTerminalConnection: - """Create a new connection object and amend to list of active connections.""" + def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: + """Create a new connection object and amend to list of active connections. + + :param connection_uuid: Connection ID of the new local connection + :param session_id: Session ID of the new local connection + :return: TerminalClientConnection object + """ new_connection = TerminalClientConnection( parent_terminal=self, connection_uuid=connection_uuid, @@ -172,7 +188,17 @@ class Terminal(Service): def login( self, username: str, password: str, ip_address: Optional[IPv4Address] = None ) -> Optional[TerminalClientConnection]: - """Login to the terminal. Will attempt a remote login if ip_address is given, else local.""" + """Login to the terminal. Will attempt a remote login if ip_address is given, else local. + + :param: username: Username used to connect to the remote node. + :type: username: str + + :param: password: Password used to connect to the remote node + :type: password: str + + :param: ip_address: Target Node IP address for login attempt. If None, login is assumed local. + :type: ip_address: Optional[IPv4Address] + """ if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning("Cannot login as service is not running.") return None @@ -199,8 +225,10 @@ class Terminal(Service): if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") # Add new local session to list of connections - self._create_local_connection(connection_uuid=connection_uuid, session_id="") - return TerminalClientConnection(parent_terminal=self, session_id="", connection_uuid=connection_uuid) + self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") + return TerminalClientConnection( + parent_terminal=self, session_id="Local_Connection", connection_uuid=connection_uuid + ) else: self.sys_log.warning("Login failed, incorrect Username or Password") return None @@ -217,7 +245,26 @@ class Terminal(Service): connection_request_id: str, is_reattempt: bool = False, ) -> Optional[RemoteTerminalConnection]: - """Process a remote login attempt.""" + """Send a remote login attempt and connect to Node. + + :param: username: Username used to connect to the remote node. + :type: username: str + + :param: password: Password used to connect to the remote node + :type: password: str + + :param: ip_address: Target Node IP address for login attempt. + :type: ip_address: IPv4Address + + :param: connection_request_id: Connection Request ID + :type: connection_request_id: str + + :param: is_reattempt: True if the request has been reattempted. Default False. + :type: is_reattempt: Optional[bool] + + :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. + + """ self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") if is_reattempt: valid_connection = self._check_client_connection(connection_id=connection_request_id) @@ -229,12 +276,24 @@ class Terminal(Service): self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None - payload = { + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password) + + payload_contents = { "type": "login_request", "username": username, "password": password, "connection_request_id": connection_request_id, } + + payload: SSHPacket = SSHPacket( + payload=payload_contents, + transport_message=transport_message, + connection_message=connection_message, + user_account=user_details, + ) + software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload=payload, dest_ip_address=ip_address, dest_port=self.port @@ -247,15 +306,28 @@ class Terminal(Service): connection_request_id=connection_request_id, ) - def _create_remote_connection(self, connection_id: str, connection_request_id: str, session_id: str) -> None: - """Create a new TerminalClientConnection Object.""" + def _create_remote_connection( + self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str + ) -> None: + """Create a new TerminalClientConnection Object. + + :param: connection_request_id: Connection Request ID + :type: connection_request_id: str + + :param: session_id: Session ID of connection. + :type: session_id: str + """ client_connection = RemoteTerminalConnection( - parent_terminal=self, session_id=session_id, connection_uuid=connection_id + parent_terminal=self, + session_id=session_id, + connection_uuid=connection_id, + source_ip=source_ip, + connection_request_id=connection_request_id, ) self._connections[connection_id] = client_connection self._client_connection_requests[connection_request_id] = client_connection - def receive(self, session_id: str, payload: Any, **kwargs) -> bool: + def receive(self, session_id: str, payload: Union[SSHPacket, Dict, List], **kwargs) -> bool: """ Receive a payload from the Software Manager. @@ -263,42 +335,62 @@ class Terminal(Service): :param session_id: The session id the payload relates to. :return: True. """ - self.sys_log.info(f"Received payload: {payload}") - if isinstance(payload, dict) and payload.get("type"): - if payload["type"] == "login_request": - # add connection - connection_request_id = payload["connection_request_id"] - username = payload["username"] - password = payload["password"] - print(f"Connection ID is: {connection_request_id}") - self.sys_log.info(f"Connection authorised, session_id: {session_id}") + source_ip = kwargs["from_network_interface"].ip_address + self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}") + if isinstance(payload, SSHPacket): + if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: + # validate & add connection + # TODO: uncomment this as part of 2781 + # connection_id = self.parent.UserSessionManager.login(username=username, password=password) + connection_id = str(uuid4()) + if connection_id: + connection_request_id = payload.connection_request_uuid + username = payload.user_account.username + password = payload.user_account.password + print(f"Connection ID is: {connection_request_id}") + self.sys_log.info(f"Connection authorised, session_id: {session_id}") + self._create_remote_connection( + connection_id=connection_id, + connection_request_id=connection_request_id, + session_id=session_id, + source_ip=source_ip, + ) + + transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS + connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + + payload_contents = { + "type": "login_success", + "username": username, + "password": password, + "connection_request_id": connection_request_id, + "connection_id": connection_id, + } + payload: SSHPacket = SSHPacket( + payload=payload_contents, + transport_message=transport_message, + connection_message=connection_message, + connection_request_uuid=connection_request_id, + connection_uuid=connection_id, + ) + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS: + self.sys_log.info("Login Successful") self._create_remote_connection( - connection_id=connection_request_id, - connection_request_id=payload["connection_request_id"], + connection_id=payload.connection_uuid, + connection_request_id=payload.connection_request_uuid, session_id=session_id, - ) - payload = { - "type": "login_success", - "username": username, - "password": password, - "connection_request_id": connection_request_id, - } - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload=payload, dest_port=self.port, session_id=session_id - ) - elif payload["type"] == "login_success": - self.sys_log.info(f"Login was successful! session_id is:{session_id}") - connection_request_id = payload["connection_request_id"] - self._create_remote_connection( - connection_id=connection_request_id, - session_id=session_id, - connection_request_id=connection_request_id, + source_ip=source_ip, ) - elif payload["type"] == "disconnect": + if isinstance(payload, dict) and payload.get("type"): + if payload["type"] == "disconnect": connection_id = payload["connection_id"] - self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server") + self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.") self._disconnect(payload["connection_id"]) if isinstance(payload, list): diff --git a/tests/integration_tests/system/test_nmap.py b/tests/integration_tests/system/test_nmap.py index 08251d71..2b8691cc 100644 --- a/tests/integration_tests/system/test_nmap.py +++ b/tests/integration_tests/system/test_nmap.py @@ -107,7 +107,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network): expected_result = { IPv4Address("192.168.10.1"): {IPProtocol.UDP: [Port.ARP]}, IPv4Address("192.168.10.22"): { - IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS, Port.SSH], + IPProtocol.TCP: [Port.HTTP, Port.FTP, Port.DNS], IPProtocol.UDP: [Port.ARP, Port.NTP], }, } diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 794e88bf..c86d6466 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -1,5 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Tuple +from uuid import uuid4 import pytest @@ -11,7 +12,12 @@ from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.protocols.ssh import SSHConnectionMessage, SSHPacket, SSHTransportMessage +from primaite.simulator.network.protocols.ssh import ( + SSHConnectionMessage, + SSHPacket, + SSHTransportMessage, + SSHUserCredentials, +) from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript @@ -155,8 +161,10 @@ def test_terminal_send(basic_network): payload: SSHPacket = SSHPacket( payload="Test_Payload", - transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST, - connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN, + transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + user_account=SSHUserCredentials(username="username", password="password"), + connection_request_uuid=str(uuid4()), ) assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address) @@ -283,8 +291,6 @@ def test_router_remote_login_to_computer(wireless_wan_network): """Test to confirm that a router can ssh into a computer.""" pc_a, _, router_1, _ = wireless_wan_network - router_1: Router = router_1 - router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal") assert len(router_1_terminal._connections) == 0 @@ -304,8 +310,6 @@ def test_router_blocks_SSH_traffic(wireless_wan_network): """Test to check that router will block SSH traffic if no ACL rule.""" pc_a, _, router_1, _ = wireless_wan_network - router_1: Router = router_1 - # Remove rule that allows SSH traffic. router_1.acl.remove_rule(position=21) From 63a689d94afa88ada4085984deaa2db8235e55a1 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 5 Aug 2024 16:25:35 +0100 Subject: [PATCH 83/95] #2706 - correcting test failures --- src/primaite/simulator/system/services/terminal/terminal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 0f8e180e..274353ed 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -292,6 +292,7 @@ class Terminal(Service): transport_message=transport_message, connection_message=connection_message, user_account=user_details, + connection_request_uuid=connection_request_id, ) software_manager: SoftwareManager = self.software_manager From 3253dd80547125635c8c13693689f15bbafc6e67 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 16:27:54 +0100 Subject: [PATCH 84/95] #2777: Update test --- .../_primaite/_game/_agent/test_probabilistic_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index f3b3c6eb..ec18f1fb 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -62,7 +62,6 @@ def test_probabilistic_agent(): reward_function=reward_function, settings={ "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - "random_seed": 120, }, ) From 3441dd25092aff65c7c9f5e9e0d11855f7bad8d7 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 5 Aug 2024 17:45:01 +0100 Subject: [PATCH 85/95] #2777: Code review changes. --- CHANGELOG.md | 4 ++-- .../game/agent/scripted_agents/probabilistic_agent.py | 2 +- src/primaite/session/environment.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c52f4678..8b3cfbb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] - +### Added +- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. @@ -22,7 +23,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML - Agent logging for agents' internal decision logic - Action masking in all PrimAITE environments -- Random Number Generator Seeding by specifying a random number seed in the config file. ### Changed - Application registry was moved to the `Application` class and now updates automatically when Application is subclassed - Databases can no longer respond to request while performing a backup diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index ab2e69ef..cd44644f 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -68,7 +68,7 @@ class ProbabilisticAgent(AbstractScriptedAgent): self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) super().__init__(agent_name, action_space, observation_space, reward_function) - self.logger.info(f"ProbabilisticAgent RNG seed: {rng_seed}") + self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}") def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a12d2eb7..c66663e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -44,9 +44,8 @@ def set_random_seed(seed: int) -> Union[None, int]: # if torch not installed don't set random seed. if sys.modules["torch"]: th.manual_seed(seed) - - th.backends.cudnn.deterministic = True - th.backends.cudnn.benchmark = False + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False return seed @@ -64,7 +63,7 @@ class PrimaiteGymEnv(gymnasium.Env): super().__init__() self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" - self.seed = self.episode_scheduler(0).get("game").get("seed") + self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" self.seed = set_random_seed(self.seed) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) From d2011ff32767730e087261f5c4b6c7b7bcf766d6 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 5 Aug 2024 22:23:54 +0100 Subject: [PATCH 86/95] #2811 - Updated syslog messaging around DatabaseClient and DatabaseService connection request and password authentication --- .../system/applications/database_client.py | 50 +++++++++++++------ .../services/database/database_service.py | 18 +++++-- src/primaite/simulator/system/software.py | 6 +-- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 06d22126..e6cfa343 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -2,7 +2,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable @@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel): if self.client and self.is_active: self.client._disconnect(self.connection_id) # noqa + def __str__(self) -> str: + return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})" + + def __repr__(self) -> str: + return str(self) + class DatabaseClient(Application, identifier="DatabaseClient"): """ @@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """Connection ID to the Database Server.""" client_connections: Dict[str, DatabaseClientConnection] = {} """Keep track of active connections to Database Server.""" - _client_connection_requests: Dict[str, Optional[str]] = {} + _client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {} """Dictionary of connection requests to Database Server.""" connected: bool = False """Boolean Value for whether connected to DB Server.""" @@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"): return False return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id) - def _check_client_connection(self, connection_id: str) -> bool: + def _validate_client_connection_request(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" return True if connection_id in self._client_connection_requests else False @@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"): :type: is_reattempt: Optional[bool] """ if is_reattempt: - valid_connection = self._check_client_connection(connection_id=connection_request_id) - if valid_connection: + valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) + if valid_connection_request: database_client_connection = self._client_connection_requests.pop(connection_request_id) - self.sys_log.info( - f"{self.name}: DatabaseClient connection to {server_ip_address} authorised." - f"Connection Request ID was {connection_request_id}." - ) - self.connected = True - self._last_connection_successful = True - return database_client_connection + if isinstance(database_client_connection, DatabaseClientConnection): + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. " + f"Using connection id {database_client_connection}" + ) + self.connected = True + self._last_connection_successful = True + return database_client_connection + else: + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined" + ) + self._last_connection_successful = False + return None else: - self.sys_log.warning( - f"{self.name}: DatabaseClient connection to {server_ip_address} declined." - f"Connection Request ID was {connection_request_id}." + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined " + f"due to unknown client-side connection request id" ) - self._last_connection_successful = False return None + payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id} software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( @@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None + self.sys_log.info( + f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}" + ) + return self._connect( server_ip_address=self.server_ip_address, password=self.server_password, diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 22ae0ff3..74ef51ee 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -191,12 +191,16 @@ class DatabaseService(Service): :return: Response to connection request containing success info. :rtype: Dict[str, Union[int, Dict[str, bool]]] """ + self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}") status_code = 500 # Default internal server error connection_id = None if self.operating_state == ServiceOperatingState.RUNNING: status_code = 503 # service unavailable if self.health_state_actual == SoftwareHealthState.OVERWHELMED: - self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at " + f"capacity." + ) if self.health_state_actual in [ SoftwareHealthState.GOOD, SoftwareHealthState.FIXING, @@ -208,12 +212,16 @@ class DatabaseService(Service): # try to create connection if not self.add_connection(connection_id=connection_id, session_id=session_id): status_code = 500 - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined") - else: - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, " + f"returning status code 500" + ) else: status_code = 401 # Unauthorised - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined") + self.sys_log.info( + f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised " + f"(incorrect password), returning status code 401" + ) else: status_code = 404 # service not found return { diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index 7c27534a..efa8c9b1 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -313,7 +313,7 @@ class IOSoftware(Software): # if over or at capacity, set to overwhelmed if len(self._connections) >= self.max_sessions: self.set_health_state(SoftwareHealthState.OVERWHELMED) - self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.") + self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.") return False else: # if service was previously overwhelmed, set to good because there is enough space for connections @@ -330,11 +330,11 @@ class IOSoftware(Software): "ip_address": session_details.with_ip_address if session_details else None, "time": datetime.now(), } - self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised") + self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised") return True # connection with given id already exists self.sys_log.warning( - f"{self.name}: Connect request for {connection_id=} declined. Connection already exists." + f"{self.name}: Connection request ({connection_id}) declined. Connection already exists." ) return False From 1e64e87798983341dac5aba4b8025a4bc2a075c5 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 09:30:27 +0100 Subject: [PATCH 87/95] #2706 - Actioning Review comments --- .../system/services/terminal.rst | 112 ++++++++++++++++++ .../simulator/network/protocols/ssh.py | 12 +- .../_system/_services/test_terminal.py | 16 +++ 3 files changed, 136 insertions(+), 4 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 37872b5b..0e362326 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -39,6 +39,12 @@ Implementation - Extends Service class. - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +Usage +===== + +The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node. + Python """""" @@ -59,3 +65,109 @@ Python ) terminal: Terminal = client.software_manager.software.get("Terminal") + +Obtaining Remote Connection +""""""""""""""""""""""""""" + + +.. code-block:: python + + from primaite.simulator.system.services.terminal.terminal import Terminal + from primaite.simulator.network.container import Network + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + + + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + + + term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") + + + +Executing a basic application install command +""""""""""""""""""""""""""""""""" + +.. code-block:: python + + from primaite.simulator.system.services.terminal.terminal import Terminal + from primaite.simulator.network.container import Network + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript + + + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + + + term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") + + term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) + + + +Creating a file on a remote node +"""""""""""""""""""""""""""""""" + +.. code-block:: python + + from primaite.simulator.system.services.terminal.terminal import Terminal + from primaite.simulator.network.container import Network + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript + + + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + + + term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") + + term_a_term_b_remote_connection.execute(["file_system", "create", "folder", "downloads"]) + + +Disconnect from Remote Node + +.. code-block:: python + + from primaite.simulator.system.services.terminal.terminal import Terminal + from primaite.simulator.network.container import Network + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript + + + network = Network() + node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0) + node_a.power_on() + node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0) + node_b.power_on() + network.connect(node_a.network_interface[1], node_b.network_interface[1]) + + terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + + + term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") + + term_a_term_b_remote_connection.disconnect() diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index 7ba629f8..ca9663d8 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -76,10 +76,14 @@ class SSHPacket(DataPacket): user_account: Optional[SSHUserCredentials] = None """User Account Credentials if passed""" - connection_request_uuid: Optional[str] = None # Connection Request uuid. + connection_request_uuid: Optional[str] = None + """Connection Request UUID used when establishing a remote connection""" - connection_uuid: Optional[str] = None # The connection uuid used to validate the session + connection_uuid: Optional[str] = None + """Connection UUID used when validating a remote connection""" - ssh_output: Optional[RequestResponse] = None # The Request Manager's returned RequestResponse + ssh_output: Optional[RequestResponse] = None + """RequestResponse from Request Manager""" - ssh_command: Optional[str] = None # This is the request string + ssh_command: Optional[str] = None + """Request String""" diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index c86d6466..7e98e501 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -336,3 +336,19 @@ def test_SSH_across_network(wireless_wan_network): terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2") assert len(terminal_a._connections) == 1 + + +def test_multiple_remote_terminals_same_node(basic_network): + """Test to check that multiple remote terminals can be spawned by one node.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert len(terminal_a._connections) == 0 + + # Spam login requests to terminal. + for attempt in range(10): + remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 10 From 457395baee922d5ef6d446872f0545d2a8d260f0 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 09:33:41 +0100 Subject: [PATCH 88/95] #2706 - Correcting wording on documentation titles --- .../source/simulation_components/system/services/terminal.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index 0e362326..5097f213 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -66,7 +66,7 @@ Python terminal: Terminal = client.software_manager.software.get("Terminal") -Obtaining Remote Connection +Creating Remote Terminal Connection """"""""""""""""""""""""""" @@ -120,7 +120,7 @@ Executing a basic application install command -Creating a file on a remote node +Creating a folder on a remote node """""""""""""""""""""""""""""""" .. code-block:: python From 89107f2c4bad297d4ab6d00bd0ab4d1e0b41f200 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 10:37:11 +0100 Subject: [PATCH 89/95] #2706 - Type-hint changes following review --- .../simulator/system/services/terminal/terminal.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 274353ed..0c94a565 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -8,7 +8,7 @@ from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.ssh import ( SSHConnectionMessage, @@ -116,27 +116,27 @@ class Terminal(Service): request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), ) - def _login(request: List[Any], context: Any) -> RequestResponse: + def _login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._process_local_login(username=request[0], password=request[1]) if login: return RequestResponse(status="success", data={}) else: return RequestResponse(status="failure", data={}) - def _remote_login(request: List[Any], context: Any) -> RequestResponse: + def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) if login: return RequestResponse(status="success", data={}) else: return RequestResponse(status="failure", data={}) - def _execute_request(request: List[Any], context: Any) -> RequestResponse: + def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" command: str = request[0] self.execute(command) return RequestResponse(status="success", data={}) - def _logoff(request: List[Any]) -> RequestResponse: + def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: """Logoff from connection.""" connection_uuid = request[0] # TODO: Uncomment this when UserSessionManager merged. From 68621f172b18e5fdb183d283f52b5bb49b25b195 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 12:10:14 +0100 Subject: [PATCH 90/95] #2706 - xfail on test_ray_multi_agent_action_masking as this is causing pipeline failures. Bugticket raised as 2812 --- .../action_masking/test_agents_use_action_masks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index 745e280b..4260c605 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -100,6 +100,7 @@ def test_ray_single_agent_action_masking(monkeypatch): monkeypatch.undo() +@pytest.mark.xfail(reason="Fails due to being flaky when run in CI.") def test_ray_multi_agent_action_masking(monkeypatch): """Check that Ray agents never take invalid actions when using MARL.""" with open(MARL_PATH, "r") as f: From df49b3b5bbb8b54190d98cfdee152bbd1be24304 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 14:10:10 +0100 Subject: [PATCH 91/95] #2706 - Actioning Review Comments --- .../system/services/terminal/terminal.py | 73 ++++++++++++------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 0c94a565..0bcec90d 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -1,11 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +from datetime import datetime from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union from uuid import uuid4 -from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel from primaite.interface.request import RequestFormat, RequestResponse @@ -41,6 +41,21 @@ class TerminalClientConnection(BaseModel): connection_request_id: str = None """Connection request ID""" + time: datetime = None + """Timestammp connection was created.""" + + ip_address: IPv4Address + """Source IP of Connection""" + + def __str__(self) -> str: + return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')" + + def __repr__(self) -> str: + return self.__str__() + + def __getitem__(self, key: Any) -> Any: + return getattr(self, key) + @property def client(self) -> Optional[Terminal]: """The Terminal that holds this connection.""" @@ -59,9 +74,6 @@ class RemoteTerminalConnection(TerminalClientConnection): """ - source_ip: IPv4Address - """Source IP of Connection""" - def execute(self, command: Any) -> bool: """Execute a given command on the remote Terminal.""" if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: @@ -73,7 +85,7 @@ class RemoteTerminalConnection(TerminalClientConnection): class Terminal(Service): """Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH.""" - _client_connection_requests: Dict[str, Optional[str]] = {} + _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} def __init__(self, **kwargs): kwargs["name"] = "Terminal" @@ -99,14 +111,7 @@ class Terminal(Service): :param markdown: Whether to display the table in Markdown format or not. Default is `False`. """ - table = PrettyTable(["Connection ID", "Connection request ID", "Source IP"]) - if markdown: - table.set_style(MARKDOWN) - table.align = "l" - table.title = f"{self.sys_log.hostname} {self.name} Connections" - for connection_id, connection in self._connections.items(): - table.add_row([connection_id, connection.connection_request_id, connection.source_ip]) - print(table.get_string(sortby="Connection ID")) + self.show_connections(markdown=markdown) def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" @@ -179,6 +184,7 @@ class Terminal(Service): parent_terminal=self, connection_uuid=connection_uuid, session_id=session_id, + time=datetime.now(), ) self._connections[connection_uuid] = new_connection self._client_connection_requests[connection_uuid] = new_connection @@ -224,19 +230,20 @@ class Terminal(Service): connection_uuid = str(uuid4()) if connection_uuid: self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}") - # Add new local session to list of connections - self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") - return TerminalClientConnection( - parent_terminal=self, session_id="Local_Connection", connection_uuid=connection_uuid - ) + # Add new local session to list of connections and return + return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection") else: self.sys_log.warning("Login failed, incorrect Username or Password") return None - def _check_client_connection(self, connection_id: str) -> bool: + def _validate_client_connection_request(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" return True if connection_id in self._client_connection_requests else False + def _check_client_connection(self, connection_id: str) -> bool: + """Check that client_connection_id is valid.""" + return True if connection_id in self._connections else False + def _send_remote_login( self, username: str, @@ -267,11 +274,15 @@ class Terminal(Service): """ self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}") if is_reattempt: - valid_connection = self._check_client_connection(connection_id=connection_request_id) - if valid_connection: + valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) + if valid_connection_request: remote_terminal_connection = self._client_connection_requests.pop(connection_request_id) - self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") - return remote_terminal_connection + if isinstance(remote_terminal_connection, RemoteTerminalConnection): + self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.") + return remote_terminal_connection + else: + self.sys_log.warning(f"Connection request{connection_request_id} declined") + return None else: self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None @@ -322,8 +333,9 @@ class Terminal(Service): parent_terminal=self, session_id=session_id, connection_uuid=connection_id, - source_ip=source_ip, + ip_address=source_ip, connection_request_id=connection_request_id, + time=datetime.now(), ) self._connections[connection_id] = client_connection self._client_connection_requests[connection_request_id] = client_connection @@ -391,8 +403,12 @@ class Terminal(Service): if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": connection_id = payload["connection_id"] - self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.") - self._disconnect(payload["connection_id"]) + valid_id = self._check_client_connection(connection_id) + if valid_id: + self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.") + self._disconnect(payload["connection_id"]) + else: + self.sys_log.info("No Active connection held for received connection ID.") if isinstance(payload, list): # A request? For me? @@ -410,8 +426,9 @@ class Terminal(Service): self.sys_log.warning("No remote connection present") return False - session_id = self._connections[connection_uuid].session_id - self._connections.pop(connection_uuid) + # session_id = self._connections[connection_uuid].session_id + connection: RemoteTerminalConnection = self._connections.pop(connection_uuid) + session_id = connection.session_id software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( From dd7e4661044387408bed49fd0a63a57e4d5c3dd9 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 15:01:53 +0100 Subject: [PATCH 92/95] #2706 - Fixing pipeline failure --- .../action_masking/test_agents_use_action_masks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py index 4260c605..addf6dca 100644 --- a/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py +++ b/tests/e2e_integration_tests/action_masking/test_agents_use_action_masks.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from typing import Dict +import pytest import yaml from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec From de14dfdc485860e5e3fa236114e44e22302fae6e Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 16:22:08 +0100 Subject: [PATCH 93/95] #2706 - Updated Changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index adf24fdc..7ce4bbf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Implemented Terminal service class, providing a generic terminal simulation. + ### Changed - Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. From d05fd00594e27c70dc7f8be9b3df1beb7e702547 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 19:09:23 +0100 Subject: [PATCH 94/95] #2706 - Resolving an issue that saw disconnected terminal connections still able to send execute commands that were also then processed by the target node. Created a new class: LocalterminalConnection, for local connection objects to terminal. Calling terminal.show() when there is a local connection will have 'Local Connection' as the IP address. Receive and execute will check that the provided connection uuid is valid before actioning any commands. TerminalClientConnection objects now have an is_active flag similar to DatabaseClientConnection. Added a new test to check that terminals will reject commands from disconnected clientconnection objects. --- .../simulator/network/protocols/ssh.py | 2 +- .../system/services/terminal/terminal.py | 104 ++++++++++++++---- .../_system/_services/test_terminal.py | 24 ++++ 3 files changed, 109 insertions(+), 21 deletions(-) diff --git a/src/primaite/simulator/network/protocols/ssh.py b/src/primaite/simulator/network/protocols/ssh.py index ca9663d8..be7f842f 100644 --- a/src/primaite/simulator/network/protocols/ssh.py +++ b/src/primaite/simulator/network/protocols/ssh.py @@ -85,5 +85,5 @@ class SSHPacket(DataPacket): ssh_output: Optional[RequestResponse] = None """RequestResponse from Request Manager""" - ssh_command: Optional[str] = None + ssh_command: Optional[list] = None """Request String""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 0bcec90d..0ebae491 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -1,6 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations +from abc import abstractmethod from datetime import datetime from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Union @@ -42,11 +43,14 @@ class TerminalClientConnection(BaseModel): """Connection request ID""" time: datetime = None - """Timestammp connection was created.""" + """Timestamp connection was created.""" ip_address: IPv4Address """Source IP of Connection""" + is_active: bool = True + """Flag to state whether the connection is active or not""" + def __str__(self) -> str: return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')" @@ -65,6 +69,28 @@ class TerminalClientConnection(BaseModel): """Disconnect the session.""" return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid) + @abstractmethod + def execute(self, command: Any) -> bool: + """Execute a given command.""" + pass + + +class LocalTerminalConnection(TerminalClientConnection): + """ + LocalTerminalConnectionClass. + + This class represents a local terminal when connected. + """ + + ip_address: str = "Local Connection" + + def execute(self, command: Any) -> RequestResponse: + """Execute a given command on local Terminal.""" + if not self.is_active: + self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") + return None + return self.parent_terminal.execute(command, connection_id=self.connection_uuid) + class RemoteTerminalConnection(TerminalClientConnection): """ @@ -78,8 +104,24 @@ class RemoteTerminalConnection(TerminalClientConnection): """Execute a given command on the remote Terminal.""" if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: self.parent_terminal.sys_log.warning("Cannot process command as system not running") + return False + if not self.is_active: + self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") + return False # Send command to remote terminal to process. - return self.parent_terminal.send(payload=command, session_id=self.session_id) + + transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST + connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA + + payload: SSHPacket = SSHPacket( + transport_message=transport_message, + connection_message=connection_message, + connection_request_uuid=self.connection_request_id, + connection_uuid=self.connection_uuid, + ssh_command=command, + ) + + return self.parent_terminal.send(payload=payload, session_id=self.session_id) class Terminal(Service): @@ -138,7 +180,8 @@ class Terminal(Service): def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" command: str = request[0] - self.execute(command) + connection_id: str = request[1] + self.execute(command, connection_id=connection_id) return RequestResponse(status="success", data={}) def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: @@ -169,9 +212,14 @@ class Terminal(Service): return rm - def execute(self, command: List[Any]) -> RequestResponse: + def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + valid_connection = self._check_client_connection(connection_id=connection_id) + if valid_connection: + return self.parent.apply_request(command) + else: + self.sys_log.error("Invalid connection ID provided") + return None def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection: """Create a new connection object and amend to list of active connections. @@ -180,7 +228,7 @@ class Terminal(Service): :param session_id: Session ID of the new local connection :return: TerminalClientConnection object """ - new_connection = TerminalClientConnection( + new_connection = LocalTerminalConnection( parent_terminal=self, connection_uuid=connection_uuid, session_id=session_id, @@ -340,7 +388,7 @@ class Terminal(Service): self._connections[connection_id] = client_connection self._client_connection_requests[connection_request_id] = client_connection - def receive(self, session_id: str, payload: Union[SSHPacket, Dict, List], **kwargs) -> bool: + def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool: """ Receive a payload from the Software Manager. @@ -400,6 +448,17 @@ class Terminal(Service): source_ip=source_ip, ) + elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: + # Requesting a command to be executed + self.sys_log.info("Received command to execute") + command = payload.ssh_command + valid_connection = self._check_client_connection(payload.connection_uuid) + self.sys_log.info(f"Connection uuid is {valid_connection}") + if valid_connection: + return self.execute(command, payload.connection_uuid) + else: + self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.") + if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": connection_id = payload["connection_id"] @@ -410,10 +469,6 @@ class Terminal(Service): else: self.sys_log.info("No Active connection held for received connection ID.") - if isinstance(payload, list): - # A request? For me? - self.execute(payload) - return True def _disconnect(self, connection_uuid: str) -> bool: @@ -426,16 +481,25 @@ class Terminal(Service): self.sys_log.warning("No remote connection present") return False - # session_id = self._connections[connection_uuid].session_id - connection: RemoteTerminalConnection = self._connections.pop(connection_uuid) - session_id = connection.session_id + connection = self._connections.pop(connection_uuid) + connection.is_active = False - software_manager: SoftwareManager = self.software_manager - software_manager.send_payload_to_session_manager( - payload={"type": "disconnect", "connection_id": connection_uuid}, dest_port=self.port, session_id=session_id - ) - self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") - return True + if isinstance(connection, RemoteTerminalConnection): + # Send disconnect command via software manager + session_id = connection.session_id + + software_manager: SoftwareManager = self.software_manager + software_manager.send_payload_to_session_manager( + payload={"type": "disconnect", "connection_id": connection_uuid}, + dest_port=self.port, + session_id=session_id, + ) + self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}") + return True + + elif isinstance(connection, LocalTerminalConnection): + # No further action needed + return True def send( self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 7e98e501..cdd0ebb3 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -352,3 +352,27 @@ def test_multiple_remote_terminals_same_node(basic_network): remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") assert len(terminal_a._connections) == 10 + + +def test_terminal_rejects_commands_if_disconnect(basic_network): + """Test to check terminal will ignore commands from disconnected connections""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + terminal_b: Terminal = computer_b.software_manager.software.get("Terminal") + + remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11") + + assert len(terminal_a._connections) == 1 + assert len(terminal_b._connections) == 1 + + remote_connection.disconnect() + + assert len(terminal_a._connections) == 0 + assert len(terminal_b._connections) == 0 + + assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False + + assert not computer_b.software_manager.software.get("RansomwareScript") From 6d6f21a20a1a02bcb89caaef4aa4b20c18e6ee94 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 6 Aug 2024 19:14:53 +0100 Subject: [PATCH 95/95] #2706 - Additional assert on new test and a guard clause on LocalTerminalConnection.execute() to check that the Terminal service is running before sending a command --- src/primaite/simulator/system/services/terminal/terminal.py | 5 ++++- .../_primaite/_simulator/_system/_services/test_terminal.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 0ebae491..4be2c501 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -84,8 +84,11 @@ class LocalTerminalConnection(TerminalClientConnection): ip_address: str = "Local Connection" - def execute(self, command: Any) -> RequestResponse: + def execute(self, command: Any) -> Optional[RequestResponse]: """Execute a given command on local Terminal.""" + if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING: + self.parent_terminal.sys_log.warning("Cannot process command as system not running") + return None if not self.is_active: self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") return None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index cdd0ebb3..9286fa49 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -376,3 +376,5 @@ def test_terminal_rejects_commands_if_disconnect(basic_network): assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False assert not computer_b.software_manager.software.get("RansomwareScript") + + assert remote_connection.is_active is False