Merge branch 'refs/heads/dev' into feature/2768_enable-multi-port-listening-for-services-and-applications
This commit is contained in:
@@ -6,6 +6,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
- Random Number Generator Seeding by specifying a random number seed in the config file.
|
||||
- Implemented Terminal service class, providing a generic terminal simulation.
|
||||
|
||||
### Changed
|
||||
- Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
|
||||
@@ -22,7 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
|
||||
- Agent logging for agents' internal decision logic
|
||||
- Action masking in all PrimAITE environments
|
||||
|
||||
### Changed
|
||||
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
|
||||
- Databases can no longer respond to request while performing a backup
|
||||
|
||||
@@ -49,3 +49,5 @@ fundamental network operations:
|
||||
5. **NTP (Network Time Protocol) Client:** Synchronises the host's clock with network time servers.
|
||||
|
||||
6. **Web Browser:** A simulated application that allows the host to request and display web content.
|
||||
|
||||
7. **Terminal:** A simulated service that allows the host to connect to remote hosts and execute commands.
|
||||
|
||||
173
docs/source/simulation_components/system/services/terminal.rst
Normal file
173
docs/source/simulation_components/system/services/terminal.rst
Normal file
@@ -0,0 +1,173 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _Terminal:
|
||||
|
||||
Terminal
|
||||
========
|
||||
|
||||
The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment.
|
||||
|
||||
|
||||
Overview
|
||||
--------
|
||||
|
||||
The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically
|
||||
installed on Nodes when they are instantiated.
|
||||
|
||||
Key capabilities
|
||||
================
|
||||
|
||||
- Ensures packets are matched to an existing session
|
||||
- Simulates common Terminal processes/commands.
|
||||
- Leverages the Service base class for install/uninstall, status tracking etc.
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
- Pre-Installs on any `Node` (component with the exception of `Switches`).
|
||||
- Terminal Clients connect, execute commands and disconnect from remote nodes.
|
||||
- Ensures that users are logged in to the component before executing any commands.
|
||||
- Service runs on SSH port 22 by default.
|
||||
|
||||
Implementation
|
||||
==============
|
||||
|
||||
- Manages remote connections in a dictionary by session ID.
|
||||
- Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate.
|
||||
- Extends Service class.
|
||||
- A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
|
||||
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node.
|
||||
|
||||
Python
|
||||
""""""
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
|
||||
client = Computer(
|
||||
hostname="client",
|
||||
ip_address="192.168.10.21",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
terminal: Terminal = client.software_manager.software.get("Terminal")
|
||||
|
||||
Creating Remote Terminal Connection
|
||||
"""""""""""""""""""""""""""
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
|
||||
|
||||
|
||||
network = Network()
|
||||
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_a.power_on()
|
||||
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
|
||||
|
||||
Executing a basic application install command
|
||||
"""""""""""""""""""""""""""""""""
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
|
||||
network = Network()
|
||||
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_a.power_on()
|
||||
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"])
|
||||
|
||||
|
||||
|
||||
Creating a folder on a remote node
|
||||
""""""""""""""""""""""""""""""""
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
|
||||
network = Network()
|
||||
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_a.power_on()
|
||||
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
term_a_term_b_remote_connection.execute(["file_system", "create", "folder", "downloads"])
|
||||
|
||||
|
||||
Disconnect from Remote Node
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
|
||||
|
||||
network = Network()
|
||||
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_a.power_on()
|
||||
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
term_a_term_b_remote_connection.disconnect()
|
||||
@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
"""Strict validation."""
|
||||
action_probabilities: Dict[int, float]
|
||||
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
|
||||
random_seed: Optional[int] = None
|
||||
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
|
||||
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
|
||||
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
|
||||
|
||||
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
|
||||
num_actions = len(action_space.action_map)
|
||||
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
|
||||
|
||||
# If seed not specified, set it to None so that numpy chooses a random one.
|
||||
settings.setdefault("random_seed")
|
||||
|
||||
# The random number seed for np.random is dependent on whether a random number seed is set
|
||||
# in the config file. If there is one it is processed by set_random_seed() in environment.py
|
||||
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
|
||||
self.settings = ProbabilisticAgent.Settings(**settings)
|
||||
|
||||
self.rng = np.random.default_rng(self.settings.random_seed)
|
||||
rng_seed = np.random.randint(0, 65535)
|
||||
self.rng = np.random.default_rng(rng_seed)
|
||||
|
||||
# convert probabilities from
|
||||
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
|
||||
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
|
||||
@@ -44,6 +44,7 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -57,6 +58,7 @@ SERVICE_TYPES_MAPPING = {
|
||||
"FTPServer": FTPServer,
|
||||
"NTPClient": NTPClient,
|
||||
"NTPServer": NTPServer,
|
||||
"Terminal": Terminal,
|
||||
}
|
||||
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
@@ -70,6 +72,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
|
||||
209
src/primaite/notebooks/Terminal-Processing.ipynb
Normal file
209
src/primaite/notebooks/Terminal-Processing.ipynb
Normal file
@@ -0,0 +1,209 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Terminal Processing\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n",
|
||||
"\n",
|
||||
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.system.services.terminal.terminal import Terminal\n",
|
||||
"from primaite.simulator.network.container import Network\n",
|
||||
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
|
||||
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n",
|
||||
"from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n",
|
||||
"\n",
|
||||
"def basic_network() -> Network:\n",
|
||||
" \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n",
|
||||
" network = Network()\n",
|
||||
" node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
|
||||
" node_a.power_on()\n",
|
||||
" node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
|
||||
" node_b.power_on()\n",
|
||||
" network.connect(node_a.network_interface[1], node_b.network_interface[1])\n",
|
||||
" return network"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The terminal can be accessed from a `Node` via the `software_manager` as demonstrated below. \n",
|
||||
"\n",
|
||||
"In the example, we have a basic network consisting of two computers, connected to form a basic network."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network: Network = basic_network()\n",
|
||||
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
|
||||
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
|
||||
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
|
||||
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are remotely logging in to the 'admin' account on `node_b`, from `node_a`. \n",
|
||||
"If you are not logged in, any commands sent will be rejected by the remote.\n",
|
||||
"\n",
|
||||
"Remote Logins return a RemoteTerminalConnection object, which can be used for sending commands to the remote node. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Login to the remote (node_b) from local (node_a)\n",
|
||||
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can view all active connections to a terminal through use of the `show()` method"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"terminal_b.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The new connection object allows us to forward commands to be executed on the target node. The example below demonstrates how you can remotely install an application on the target node."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"computer_b.software_manager.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to create a downloads folder. \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display the current state of the file system on computer_b\n",
|
||||
"computer_b.file_system.show()\n",
|
||||
"\n",
|
||||
"# Send command\n",
|
||||
"term_a_term_b_remote_connection.execute([\"file_system\", \"create\", \"folder\", \"downloads\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The resultant call to `computer_b.file_system.show()` shows that the new folder has been created."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"computer_b.file_system.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When finished, the connection can be closed by calling the `disconnect` function of the Remote Client object"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display active connection\n",
|
||||
"terminal_a.show()\n",
|
||||
"terminal_b.show()\n",
|
||||
"\n",
|
||||
"term_a_term_b_remote_connection.disconnect()\n",
|
||||
"\n",
|
||||
"terminal_a.show()\n",
|
||||
"\n",
|
||||
"terminal_b.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
# Check torch is installed
|
||||
try:
|
||||
import torch as th
|
||||
except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
:param seed: int
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
random.seed(seed)
|
||||
# Seed numpy RNG
|
||||
np.random.seed(seed)
|
||||
# Seed the RNG for all devices (both CPU and CUDA)
|
||||
# if torch not installed don't set random seed.
|
||||
if sys.modules["torch"]:
|
||||
th.manual_seed(seed)
|
||||
th.backends.cudnn.deterministic = True
|
||||
th.backends.cudnn.benchmark = False
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
super().__init__()
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
super().reset() # Ensure PRNG seed is set everywhere
|
||||
if self.env.agent.action_masking:
|
||||
obs, *_ = self.env.reset(seed=seed)
|
||||
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
|
||||
|
||||
@@ -30,6 +30,7 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.simulator.system.software import IOSoftware, Software
|
||||
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
|
||||
from primaite.utils.validators import IPV4Address
|
||||
@@ -1541,6 +1542,11 @@ class Node(SimComponent):
|
||||
"""The Nodes User Session Manager."""
|
||||
return self.software_manager.software.get("UserSessionManager") # noqa
|
||||
|
||||
@property
|
||||
def terminal(self) -> Optional[Terminal]:
|
||||
"""The Nodes Terminal."""
|
||||
return self.software_manager.software.get("Terminal")
|
||||
|
||||
def local_login(self, username: str, password: str) -> Optional[str]:
|
||||
"""
|
||||
Attempt to log in to the node uas a local user.
|
||||
|
||||
@@ -21,6 +21,7 @@ from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.icmp.icmp import ICMP
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -298,6 +299,7 @@ class HostNode(Node):
|
||||
* DNS (Domain Name System) Client: Resolves domain names to IP addresses.
|
||||
* FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers.
|
||||
* NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers.
|
||||
* Terminal Client: Handles SSH requests between HostNode and external components.
|
||||
|
||||
Applications:
|
||||
------------
|
||||
@@ -314,6 +316,7 @@ class HostNode(Node):
|
||||
"NMAP": NMAP,
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
}
|
||||
"""List of system software that is automatically installed on nodes."""
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
from primaite.simulator.system.services.icmp.icmp import ICMP
|
||||
from primaite.simulator.system.services.terminal.terminal import Terminal
|
||||
from primaite.utils.validators import IPV4Address
|
||||
|
||||
|
||||
@@ -1203,6 +1204,7 @@ class Router(NetworkNode):
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
}
|
||||
|
||||
num_ports: int
|
||||
|
||||
89
src/primaite/simulator/network/protocols/ssh.py
Normal file
89
src/primaite/simulator/network/protocols/ssh.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class SSHTransportMessage(IntEnum):
|
||||
"""
|
||||
Enum list of Transport layer messages that can be handled by the simulation.
|
||||
|
||||
Each msg value is equivalent to the real-world.
|
||||
"""
|
||||
|
||||
SSH_MSG_USERAUTH_REQUEST = 50
|
||||
"""Requests User Authentication."""
|
||||
|
||||
SSH_MSG_USERAUTH_FAILURE = 51
|
||||
"""Indicates User Authentication failed."""
|
||||
|
||||
SSH_MSG_USERAUTH_SUCCESS = 52
|
||||
"""Indicates User Authentication was successful."""
|
||||
|
||||
SSH_MSG_SERVICE_REQUEST = 24
|
||||
"""Requests a service - such as executing a command."""
|
||||
|
||||
# These two msgs are invented for primAITE however are modelled on reality
|
||||
|
||||
SSH_MSG_SERVICE_FAILED = 25
|
||||
"""Indicates that the requested service failed."""
|
||||
|
||||
SSH_MSG_SERVICE_SUCCESS = 26
|
||||
"""Indicates that the requested service was successful."""
|
||||
|
||||
|
||||
class SSHConnectionMessage(IntEnum):
|
||||
"""Int Enum list of all SSH's connection protocol messages that can be handled by the simulation."""
|
||||
|
||||
SSH_MSG_CHANNEL_OPEN = 80
|
||||
"""Requests an open channel - Used in combination with SSH_MSG_USERAUTH_REQUEST."""
|
||||
|
||||
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 81
|
||||
"""Confirms an open channel."""
|
||||
|
||||
SSH_MSG_CHANNEL_OPEN_FAILED = 82
|
||||
"""Indicates that channel opening failure."""
|
||||
|
||||
SSH_MSG_CHANNEL_DATA = 84
|
||||
"""Indicates that data is being sent through the channel."""
|
||||
|
||||
SSH_MSG_CHANNEL_CLOSE = 87
|
||||
"""Closes the channel."""
|
||||
|
||||
|
||||
class SSHUserCredentials(DataPacket):
|
||||
"""Hold Username and Password in SSH Packets."""
|
||||
|
||||
username: str
|
||||
"""Username for login"""
|
||||
|
||||
password: str
|
||||
"""Password for login"""
|
||||
|
||||
|
||||
class SSHPacket(DataPacket):
|
||||
"""Represents an SSHPacket."""
|
||||
|
||||
transport_message: SSHTransportMessage
|
||||
"""Message Transport Type"""
|
||||
|
||||
connection_message: SSHConnectionMessage
|
||||
"""Message Connection Status"""
|
||||
|
||||
user_account: Optional[SSHUserCredentials] = None
|
||||
"""User Account Credentials if passed"""
|
||||
|
||||
connection_request_uuid: Optional[str] = None
|
||||
"""Connection Request UUID used when establishing a remote connection"""
|
||||
|
||||
connection_uuid: Optional[str] = None
|
||||
"""Connection UUID used when validating a remote connection"""
|
||||
|
||||
ssh_output: Optional[RequestResponse] = None
|
||||
"""RequestResponse from Request Manager"""
|
||||
|
||||
ssh_command: Optional[list] = None
|
||||
"""Request String"""
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
|
||||
if self.client and self.is_active:
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
|
||||
class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
@@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
"""Keep track of active connections to Database Server."""
|
||||
_client_connection_requests: Dict[str, Optional[str]] = {}
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
|
||||
"""Dictionary of connection requests to Database Server."""
|
||||
connected: bool = False
|
||||
"""Boolean Value for whether connected to DB Server."""
|
||||
@@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
return False
|
||||
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
def _validate_client_connection_request(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._client_connection_requests else False
|
||||
|
||||
@@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
valid_connection = self._check_client_connection(connection_id=connection_request_id)
|
||||
if valid_connection:
|
||||
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
||||
if valid_connection_request:
|
||||
database_client_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
if isinstance(database_client_connection, DatabaseClientConnection):
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
|
||||
f"Using connection id {database_client_connection}"
|
||||
)
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
|
||||
f"due to unknown client-side connection request id"
|
||||
)
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
|
||||
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
@@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
|
||||
)
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
|
||||
@@ -191,12 +191,16 @@ class DatabaseService(Service):
|
||||
:return: Response to connection request containing success info.
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
|
||||
status_code = 500 # Default internal server error
|
||||
connection_id = None
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
|
||||
f"capacity."
|
||||
)
|
||||
if self.health_state_actual in [
|
||||
SoftwareHealthState.GOOD,
|
||||
SoftwareHealthState.FIXING,
|
||||
@@ -208,12 +212,16 @@ class DatabaseService(Service):
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id, session_id=session_id):
|
||||
status_code = 500
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
|
||||
f"returning status code 500"
|
||||
)
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
|
||||
f"(incorrect password), returning status code 401"
|
||||
)
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
523
src/primaite/simulator/system/services/terminal/terminal.py
Normal file
523
src/primaite/simulator/system/services/terminal/terminal.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestFormat, RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.protocols.ssh import (
|
||||
SSHConnectionMessage,
|
||||
SSHPacket,
|
||||
SSHTransportMessage,
|
||||
SSHUserCredentials,
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
|
||||
|
||||
class TerminalClientConnection(BaseModel):
|
||||
"""
|
||||
TerminalClientConnection Class.
|
||||
|
||||
This class is used to record current User Connections to the Terminal class.
|
||||
"""
|
||||
|
||||
parent_terminal: Terminal
|
||||
"""The parent Node that this connection was created on."""
|
||||
|
||||
session_id: str = None
|
||||
"""Session ID that connection is linked to"""
|
||||
|
||||
connection_uuid: str = None
|
||||
"""Connection UUID"""
|
||||
|
||||
connection_request_id: str = None
|
||||
"""Connection request ID"""
|
||||
|
||||
time: datetime = None
|
||||
"""Timestamp connection was created."""
|
||||
|
||||
ip_address: IPv4Address
|
||||
"""Source IP of Connection"""
|
||||
|
||||
is_active: bool = True
|
||||
"""Flag to state whether the connection is active or not"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[Terminal]:
|
||||
"""The Terminal that holds this connection."""
|
||||
return self.parent_terminal
|
||||
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect the session."""
|
||||
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: Any) -> bool:
|
||||
"""Execute a given command."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
LocalTerminalConnectionClass.
|
||||
|
||||
This class represents a local terminal when connected.
|
||||
"""
|
||||
|
||||
ip_address: str = "Local Connection"
|
||||
|
||||
def execute(self, command: Any) -> Optional[RequestResponse]:
|
||||
"""Execute a given command on local Terminal."""
|
||||
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
|
||||
return None
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return None
|
||||
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
|
||||
|
||||
|
||||
class RemoteTerminalConnection(TerminalClientConnection):
|
||||
"""
|
||||
RemoteTerminalConnection Class.
|
||||
|
||||
This class acts as broker between the terminal and remote.
|
||||
|
||||
"""
|
||||
|
||||
def execute(self, command: Any) -> bool:
|
||||
"""Execute a given command on the remote Terminal."""
|
||||
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
|
||||
return False
|
||||
if not self.is_active:
|
||||
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
|
||||
return False
|
||||
# Send command to remote terminal to process.
|
||||
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_request_uuid=self.connection_request_id,
|
||||
connection_uuid=self.connection_uuid,
|
||||
ssh_command=command,
|
||||
)
|
||||
|
||||
return self.parent_terminal.send(payload=payload, session_id=self.session_id)
|
||||
|
||||
|
||||
class Terminal(Service):
|
||||
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
|
||||
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "Terminal"
|
||||
kwargs["port"] = Port.SSH
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Display the remote connections to this terminal instance in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
self.show_connections(markdown=markdown)
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
"""Initialise Request manager."""
|
||||
rm = super()._init_request_manager()
|
||||
rm.add_request(
|
||||
"send",
|
||||
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
|
||||
)
|
||||
|
||||
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
login = self._process_local_login(username=request[0], password=request[1])
|
||||
if login:
|
||||
return RequestResponse(status="success", data={})
|
||||
else:
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
|
||||
if login:
|
||||
return RequestResponse(status="success", data={})
|
||||
else:
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Execute an instruction."""
|
||||
command: str = request[0]
|
||||
connection_id: str = request[1]
|
||||
self.execute(command, connection_id=connection_id)
|
||||
return RequestResponse(status="success", data={})
|
||||
|
||||
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Logoff from connection."""
|
||||
connection_uuid = request[0]
|
||||
# TODO: Uncomment this when UserSessionManager merged.
|
||||
# self.parent.UserSessionManager.logoff(connection_uuid)
|
||||
self._disconnect(connection_uuid)
|
||||
|
||||
return RequestResponse(status="success", data={})
|
||||
|
||||
rm.add_request(
|
||||
"Login",
|
||||
request_type=RequestType(func=_login),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"Remote Login",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"Execute",
|
||||
request_type=RequestType(func=_execute_request),
|
||||
)
|
||||
|
||||
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]:
|
||||
"""Execute a passed ssh command via the request manager."""
|
||||
valid_connection = self._check_client_connection(connection_id=connection_id)
|
||||
if valid_connection:
|
||||
return self.parent.apply_request(command)
|
||||
else:
|
||||
self.sys_log.error("Invalid connection ID provided")
|
||||
return None
|
||||
|
||||
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
|
||||
"""Create a new connection object and amend to list of active connections.
|
||||
|
||||
:param connection_uuid: Connection ID of the new local connection
|
||||
:param session_id: Session ID of the new local connection
|
||||
:return: TerminalClientConnection object
|
||||
"""
|
||||
new_connection = LocalTerminalConnection(
|
||||
parent_terminal=self,
|
||||
connection_uuid=connection_uuid,
|
||||
session_id=session_id,
|
||||
time=datetime.now(),
|
||||
)
|
||||
self._connections[connection_uuid] = new_connection
|
||||
self._client_connection_requests[connection_uuid] = new_connection
|
||||
|
||||
return new_connection
|
||||
|
||||
def login(
|
||||
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
|
||||
) -> Optional[TerminalClientConnection]:
|
||||
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local.
|
||||
|
||||
:param: username: Username used to connect to the remote node.
|
||||
:type: username: str
|
||||
|
||||
:param: password: Password used to connect to the remote node
|
||||
:type: password: str
|
||||
|
||||
:param: ip_address: Target Node IP address for login attempt. If None, login is assumed local.
|
||||
:type: ip_address: Optional[IPv4Address]
|
||||
"""
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.warning("Cannot login as service is not running.")
|
||||
return None
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
if ip_address:
|
||||
# Assuming that if IP is passed we are connecting to remote
|
||||
return self._send_remote_login(
|
||||
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
|
||||
)
|
||||
else:
|
||||
return self._process_local_login(username=username, password=password)
|
||||
|
||||
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
|
||||
"""Local session login to terminal.
|
||||
|
||||
:param username: Username for login.
|
||||
:param password: Password for login.
|
||||
:return: boolean, True if successful, else False
|
||||
"""
|
||||
# TODO: Un-comment this when UserSessionManager is merged.
|
||||
# connection_uuid = self.parent.UserSessionManager.login(username=username, password=password)
|
||||
connection_uuid = str(uuid4())
|
||||
if connection_uuid:
|
||||
self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}")
|
||||
# Add new local session to list of connections and return
|
||||
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
|
||||
else:
|
||||
self.sys_log.warning("Login failed, incorrect Username or Password")
|
||||
return None
|
||||
|
||||
def _validate_client_connection_request(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._client_connection_requests else False
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._connections else False
|
||||
|
||||
def _send_remote_login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
ip_address: IPv4Address,
|
||||
connection_request_id: str,
|
||||
is_reattempt: bool = False,
|
||||
) -> Optional[RemoteTerminalConnection]:
|
||||
"""Send a remote login attempt and connect to Node.
|
||||
|
||||
:param: username: Username used to connect to the remote node.
|
||||
:type: username: str
|
||||
|
||||
:param: password: Password used to connect to the remote node
|
||||
:type: password: str
|
||||
|
||||
:param: ip_address: Target Node IP address for login attempt.
|
||||
:type: ip_address: IPv4Address
|
||||
|
||||
:param: connection_request_id: Connection Request ID
|
||||
:type: connection_request_id: str
|
||||
|
||||
:param: is_reattempt: True if the request has been reattempted. Default False.
|
||||
:type: is_reattempt: Optional[bool]
|
||||
|
||||
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
|
||||
|
||||
"""
|
||||
self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}")
|
||||
if is_reattempt:
|
||||
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
|
||||
if valid_connection_request:
|
||||
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
if isinstance(remote_terminal_connection, RemoteTerminalConnection):
|
||||
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
|
||||
return remote_terminal_connection
|
||||
else:
|
||||
self.sys_log.warning(f"Connection request{connection_request_id} declined")
|
||||
return None
|
||||
else:
|
||||
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
|
||||
return None
|
||||
|
||||
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
|
||||
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
|
||||
|
||||
payload_contents = {
|
||||
"type": "login_request",
|
||||
"username": username,
|
||||
"password": password,
|
||||
"connection_request_id": connection_request_id,
|
||||
}
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=payload_contents,
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
user_account=user_details,
|
||||
connection_request_uuid=connection_request_id,
|
||||
)
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=ip_address, dest_port=self.port
|
||||
)
|
||||
return self._send_remote_login(
|
||||
username=username,
|
||||
password=password,
|
||||
ip_address=ip_address,
|
||||
is_reattempt=True,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def _create_remote_connection(
|
||||
self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str
|
||||
) -> None:
|
||||
"""Create a new TerminalClientConnection Object.
|
||||
|
||||
:param: connection_request_id: Connection Request ID
|
||||
:type: connection_request_id: str
|
||||
|
||||
:param: session_id: Session ID of connection.
|
||||
:type: session_id: str
|
||||
"""
|
||||
client_connection = RemoteTerminalConnection(
|
||||
parent_terminal=self,
|
||||
session_id=session_id,
|
||||
connection_uuid=connection_id,
|
||||
ip_address=source_ip,
|
||||
connection_request_id=connection_request_id,
|
||||
time=datetime.now(),
|
||||
)
|
||||
self._connections[connection_id] = client_connection
|
||||
self._client_connection_requests[connection_request_id] = client_connection
|
||||
|
||||
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
|
||||
:param payload: A payload to receive.
|
||||
:param session_id: The session id the payload relates to.
|
||||
:return: True.
|
||||
"""
|
||||
source_ip = kwargs["from_network_interface"].ip_address
|
||||
self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}")
|
||||
if isinstance(payload, SSHPacket):
|
||||
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate & add connection
|
||||
# TODO: uncomment this as part of 2781
|
||||
# connection_id = self.parent.UserSessionManager.login(username=username, password=password)
|
||||
connection_id = str(uuid4())
|
||||
if connection_id:
|
||||
connection_request_id = payload.connection_request_uuid
|
||||
username = payload.user_account.username
|
||||
password = payload.user_account.password
|
||||
print(f"Connection ID is: {connection_request_id}")
|
||||
self.sys_log.info(f"Connection authorised, session_id: {session_id}")
|
||||
self._create_remote_connection(
|
||||
connection_id=connection_id,
|
||||
connection_request_id=connection_request_id,
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
|
||||
transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
|
||||
connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
|
||||
|
||||
payload_contents = {
|
||||
"type": "login_success",
|
||||
"username": username,
|
||||
"password": password,
|
||||
"connection_request_id": connection_request_id,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=payload_contents,
|
||||
transport_message=transport_message,
|
||||
connection_message=connection_message,
|
||||
connection_request_uuid=connection_request_id,
|
||||
connection_uuid=connection_id,
|
||||
)
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
|
||||
self.sys_log.info("Login Successful")
|
||||
self._create_remote_connection(
|
||||
connection_id=payload.connection_uuid,
|
||||
connection_request_id=payload.connection_request_uuid,
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
|
||||
# Requesting a command to be executed
|
||||
self.sys_log.info("Received command to execute")
|
||||
command = payload.ssh_command
|
||||
valid_connection = self._check_client_connection(payload.connection_uuid)
|
||||
self.sys_log.info(f"Connection uuid is {valid_connection}")
|
||||
if valid_connection:
|
||||
return self.execute(command, payload.connection_uuid)
|
||||
else:
|
||||
self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.")
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "disconnect":
|
||||
connection_id = payload["connection_id"]
|
||||
valid_id = self._check_client_connection(connection_id)
|
||||
if valid_id:
|
||||
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
|
||||
self._disconnect(payload["connection_id"])
|
||||
else:
|
||||
self.sys_log.info("No Active connection held for received connection ID.")
|
||||
|
||||
return True
|
||||
|
||||
def _disconnect(self, connection_uuid: str) -> bool:
|
||||
"""Disconnect from the remote.
|
||||
|
||||
:param connection_uuid: Connection ID that we want to disconnect.
|
||||
:return True if successful, False otherwise.
|
||||
"""
|
||||
if not self._connections:
|
||||
self.sys_log.warning("No remote connection present")
|
||||
return False
|
||||
|
||||
connection = self._connections.pop(connection_uuid)
|
||||
connection.is_active = False
|
||||
|
||||
if isinstance(connection, RemoteTerminalConnection):
|
||||
# Send disconnect command via software manager
|
||||
session_id = connection.session_id
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_uuid},
|
||||
dest_port=self.port,
|
||||
session_id=session_id,
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
|
||||
return True
|
||||
|
||||
elif isinstance(connection, LocalTerminalConnection):
|
||||
# No further action needed
|
||||
return True
|
||||
|
||||
def send(
|
||||
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a payload out from the Terminal.
|
||||
|
||||
:param payload: The payload to be sent.
|
||||
:param dest_up_address: The IP address of the payload destination.
|
||||
"""
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!")
|
||||
return False
|
||||
|
||||
self.sys_log.debug(f"Sending payload: {payload}")
|
||||
return super().send(
|
||||
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
@@ -316,7 +316,7 @@ class IOSoftware(Software):
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
@@ -333,11 +333,11 @@ class IOSoftware(Software):
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
f"{self.name}: Connection request ({connection_id}) declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
|
||||
@@ -100,6 +101,7 @@ def test_ray_single_agent_action_masking(monkeypatch):
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Fails due to being flaky when run in CI.")
|
||||
def test_ray_multi_agent_action_masking(monkeypatch):
|
||||
"""Check that Ray agents never take invalid actions when using MARL."""
|
||||
with open(MARL_PATH, "r") as f:
|
||||
|
||||
50
tests/integration_tests/game_layer/test_RNG_seed.py
Normal file
50
tests/integration_tests/game_layer/test_RNG_seed.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pprint import pprint
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def create_env():
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
return env
|
||||
|
||||
|
||||
def test_rng_seed_set(create_env):
|
||||
"""Test with RNG seed set."""
|
||||
env = create_env
|
||||
env.reset(seed=3)
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
|
||||
env.reset(seed=3)
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
|
||||
assert a == b
|
||||
|
||||
|
||||
def test_rng_seed_unset(create_env):
|
||||
"""Test with no RNG seed."""
|
||||
env = create_env
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
env.step(0)
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
|
||||
|
||||
assert a != b
|
||||
@@ -62,7 +62,6 @@ def test_probabilistic_agent():
|
||||
reward_function=reward_function,
|
||||
settings={
|
||||
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
|
||||
"random_seed": 120,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from primaite.simulator.network.protocols.ssh import (
|
||||
SSHConnectionMessage,
|
||||
SSHPacket,
|
||||
SSHTransportMessage,
|
||||
SSHUserCredentials,
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def terminal_on_computer() -> Tuple[Terminal, Computer]:
|
||||
computer: Computer = Computer(
|
||||
hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
computer.power_on()
|
||||
terminal: Terminal = computer.software_manager.software.get("Terminal")
|
||||
|
||||
return terminal, computer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def basic_network() -> Network:
|
||||
network = Network()
|
||||
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_a.power_on()
|
||||
node_a.software_manager.get_open_ports()
|
||||
|
||||
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
return network
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def wireless_wan_network():
|
||||
network = Network()
|
||||
|
||||
# Configure PC A
|
||||
pc_a = Computer(
|
||||
hostname="pc_a",
|
||||
ip_address="192.168.0.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.0.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_a.power_on()
|
||||
network.add_node(pc_a)
|
||||
|
||||
# Configure Router 1
|
||||
router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace)
|
||||
router_1.power_on()
|
||||
network.add_node(router_1)
|
||||
|
||||
# Configure the connection between PC A and Router 1 port 2
|
||||
router_1.configure_router_interface("192.168.0.1", "255.255.255.0")
|
||||
network.connect(pc_a.network_interface[1], router_1.network_interface[2])
|
||||
|
||||
# Configure Router 1 ACLs
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
# add ACL rule to allow SSH traffic
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21)
|
||||
|
||||
# Configure PC B
|
||||
pc_b = Computer(
|
||||
hostname="pc_b",
|
||||
ip_address="192.168.2.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.2.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
pc_b.power_on()
|
||||
network.add_node(pc_b)
|
||||
|
||||
# Configure Router 2
|
||||
router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace)
|
||||
router_2.power_on()
|
||||
network.add_node(router_2)
|
||||
|
||||
# Configure the connection between PC B and Router 2 port 2
|
||||
router_2.configure_router_interface("192.168.2.1", "255.255.255.0")
|
||||
network.connect(pc_b.network_interface[1], router_2.network_interface[2])
|
||||
|
||||
# Configure Router 2 ACLs
|
||||
|
||||
# Configure the wireless connection between Router 1 port 1 and Router 2 port 1
|
||||
router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0")
|
||||
router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0")
|
||||
|
||||
router_1.route_table.add_route(
|
||||
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
|
||||
)
|
||||
|
||||
# Configure Route from Router 2 to PC A subnet
|
||||
router_2.route_table.add_route(
|
||||
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
|
||||
)
|
||||
|
||||
return pc_a, pc_b, router_1, router_2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def game_and_agent_fixture(game_and_agent):
|
||||
"""Create a game with a simple agent that can be controlled by the tests."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.start_up_duration = 3
|
||||
|
||||
return game, agent
|
||||
|
||||
|
||||
def test_terminal_creation(terminal_on_computer):
|
||||
terminal, computer = terminal_on_computer
|
||||
terminal.describe_state()
|
||||
|
||||
|
||||
def test_terminal_install_default():
|
||||
"""Terminal should be auto installed onto Nodes"""
|
||||
computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
|
||||
computer.power_on()
|
||||
|
||||
assert computer.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
def test_terminal_not_on_switch():
|
||||
"""Ensure terminal does not auto-install to switch"""
|
||||
test_switch = Switch(hostname="Test")
|
||||
|
||||
assert not test_switch.software_manager.software.get("Terminal")
|
||||
|
||||
|
||||
def test_terminal_send(basic_network):
|
||||
"""Test that Terminal can send valid commands."""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload="Test_Payload",
|
||||
transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
|
||||
user_account=SSHUserCredentials(username="username", password="password"),
|
||||
connection_request_uuid=str(uuid4()),
|
||||
)
|
||||
|
||||
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
|
||||
|
||||
|
||||
def test_terminal_receive(basic_network):
|
||||
"""Test that terminal can receive and process commands"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
folder_name = "Downloads"
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=["file_system", "create", "folder", folder_name],
|
||||
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
|
||||
)
|
||||
|
||||
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
|
||||
username="username", password="password", ip_address="192.168.0.11"
|
||||
)
|
||||
|
||||
term_a_on_node_b.execute(["file_system", "create", "folder", folder_name])
|
||||
|
||||
# Assert that the Folder has been correctly created
|
||||
assert computer_b.file_system.get_folder(folder_name)
|
||||
|
||||
|
||||
def test_terminal_install(basic_network):
|
||||
"""Test that Terminal can successfully process an INSTALL request"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=["software_manager", "application", "install", "RansomwareScript"],
|
||||
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
|
||||
)
|
||||
|
||||
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
|
||||
username="username", password="password", ip_address="192.168.0.11"
|
||||
)
|
||||
|
||||
term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"])
|
||||
|
||||
assert computer_b.software_manager.software.get("RansomwareScript")
|
||||
|
||||
|
||||
def test_terminal_fail_when_closed(basic_network):
|
||||
"""Ensure Terminal won't attempt to send/receive when off"""
|
||||
network: Network = basic_network
|
||||
computer: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal: Terminal = computer.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
terminal.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
assert not terminal.login(
|
||||
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
|
||||
)
|
||||
|
||||
|
||||
def test_terminal_disconnect(basic_network):
|
||||
"""Test Terminal disconnects"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
|
||||
|
||||
assert len(terminal_b._connections) == 0
|
||||
|
||||
term_a_on_term_b = terminal_a.login(
|
||||
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
|
||||
)
|
||||
|
||||
assert len(terminal_b._connections) == 1
|
||||
|
||||
term_a_on_term_b.disconnect()
|
||||
|
||||
assert len(terminal_b._connections) == 0
|
||||
|
||||
|
||||
def test_terminal_ignores_when_off(basic_network):
|
||||
"""Terminal should ignore commands when not running"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
term_a_on_term_b: RemoteTerminalConnection = terminal_a.login(
|
||||
username="admin", password="Admin123!", ip_address="192.168.0.11"
|
||||
) # login to computer_b
|
||||
|
||||
terminal_a.operating_state = ServiceOperatingState.STOPPED
|
||||
|
||||
assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"])
|
||||
|
||||
|
||||
def test_computer_remote_login_to_router(wireless_wan_network):
|
||||
"""Test to confirm that a computer can SSH into a router."""
|
||||
pc_a, _, router_1, _ = wireless_wan_network
|
||||
|
||||
pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal")
|
||||
|
||||
assert len(pc_a_terminal._connections) == 0
|
||||
|
||||
pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1")
|
||||
|
||||
assert len(pc_a_terminal._connections) == 1
|
||||
|
||||
payload = ["software_manager", "application", "install", "RansomwareScript"]
|
||||
|
||||
pc_a_on_router_1.execute(payload)
|
||||
|
||||
assert router_1.software_manager.software.get("RansomwareScript")
|
||||
|
||||
|
||||
def test_router_remote_login_to_computer(wireless_wan_network):
|
||||
"""Test to confirm that a router can ssh into a computer."""
|
||||
pc_a, _, router_1, _ = wireless_wan_network
|
||||
|
||||
router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal")
|
||||
|
||||
assert len(router_1_terminal._connections) == 0
|
||||
|
||||
router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2")
|
||||
|
||||
assert len(router_1_terminal._connections) == 1
|
||||
|
||||
payload = ["software_manager", "application", "install", "RansomwareScript"]
|
||||
|
||||
router_1_on_pc_a.execute(payload)
|
||||
|
||||
assert pc_a.software_manager.software.get("RansomwareScript")
|
||||
|
||||
|
||||
def test_router_blocks_SSH_traffic(wireless_wan_network):
|
||||
"""Test to check that router will block SSH traffic if no ACL rule."""
|
||||
pc_a, _, router_1, _ = wireless_wan_network
|
||||
|
||||
# Remove rule that allows SSH traffic.
|
||||
router_1.acl.remove_rule(position=21)
|
||||
|
||||
pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal")
|
||||
|
||||
assert len(pc_a_terminal._connections) == 0
|
||||
|
||||
pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2")
|
||||
|
||||
assert len(pc_a_terminal._connections) == 0
|
||||
|
||||
|
||||
def test_SSH_across_network(wireless_wan_network):
|
||||
"""Test to show ability to SSH across a network."""
|
||||
pc_a, pc_b, router_1, router_2 = wireless_wan_network
|
||||
|
||||
terminal_a: Terminal = pc_a.software_manager.software.get("Terminal")
|
||||
terminal_b: Terminal = pc_b.software_manager.software.get("Terminal")
|
||||
|
||||
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21)
|
||||
|
||||
assert len(terminal_a._connections) == 0
|
||||
|
||||
terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2")
|
||||
|
||||
assert len(terminal_a._connections) == 1
|
||||
|
||||
|
||||
def test_multiple_remote_terminals_same_node(basic_network):
|
||||
"""Test to check that multiple remote terminals can be spawned by one node."""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
assert len(terminal_a._connections) == 0
|
||||
|
||||
# Spam login requests to terminal.
|
||||
for attempt in range(10):
|
||||
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
|
||||
|
||||
assert len(terminal_a._connections) == 10
|
||||
|
||||
|
||||
def test_terminal_rejects_commands_if_disconnect(basic_network):
|
||||
"""Test to check terminal will ignore commands from disconnected connections"""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
|
||||
|
||||
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
|
||||
|
||||
assert len(terminal_a._connections) == 1
|
||||
assert len(terminal_b._connections) == 1
|
||||
|
||||
remote_connection.disconnect()
|
||||
|
||||
assert len(terminal_a._connections) == 0
|
||||
assert len(terminal_b._connections) == 0
|
||||
|
||||
assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False
|
||||
|
||||
assert not computer_b.software_manager.software.get("RansomwareScript")
|
||||
|
||||
assert remote_connection.is_active is False
|
||||
Reference in New Issue
Block a user