Merge branch 'dev' into feature/2777_set_RNG_seed

This commit is contained in:
Nick Todd
2024-08-05 11:12:30 +01:00
39 changed files with 1552 additions and 382 deletions

View File

@@ -45,7 +45,7 @@ def test_fix_duration_set_from_config():
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in SERVICE_TYPES_MAPPING:
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
@@ -53,7 +53,7 @@ def test_fix_duration_set_from_config():
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
for application in applications:
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
@@ -64,17 +64,13 @@ def test_fix_duration_for_one_item():
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
services = copy.copy(SERVICE_TYPES_MAPPING)
services.pop("DatabaseService")
for service in services:
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications.remove("DatabaseClient")
for applications in applications:
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2

View File

@@ -9,9 +9,11 @@ from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
@@ -75,6 +77,18 @@ def test_nic(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
)
simulation.pre_timestep(0) # apply timestep to whole sim

View File

@@ -1,12 +1,14 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
def test_capture_nmne(uc2_network):
def test_capture_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
assert web_server_nic.nmne == {}
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
def test_describe_state_nmne(uc2_network):
def test_describe_state_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
web_server_nic_state = web_server_nic.describe_state()
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
def test_capture_nmne_observations(uc2_network):
def test_capture_nmne_observations(uc2_network: Network):
"""
Tests the NICObservation class's functionality within a simulated network environment.
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)

View File

@@ -0,0 +1,26 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from tests import TEST_ASSETS_ROOT
def test_users_from_config():
config_path = TEST_ASSETS_ROOT / "configs" / "basic_node_with_users.yaml"
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
client_1 = network.get_node_by_hostname("client_1")
user_manager: UserManager = client_1.software_manager.software["UserManager"]
assert len(user_manager.users) == 3
assert user_manager.users["jane.doe"].password == "1234"
assert user_manager.users["jane.doe"].is_admin
assert user_manager.users["john.doe"].password == "password_1"
assert not user_manager.users["john.doe"].is_admin

View File

@@ -0,0 +1,274 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
from uuid import uuid4
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import User
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@pytest.fixture(scope="function")
def client_server_network() -> Tuple[Computer, Server, Network]:
network = Network()
client = Computer(
hostname="client",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client.power_on()
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
network.connect(client.network_interface[1], server.network_interface[1])
return client, server, network
def test_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
def test_login_count_increases(client_server_network):
client, server, network = client_server_network
admin_user: User = client.user_manager.users["admin"]
assert admin_user.num_of_logins == 0
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 1
client.user_session_manager.local_login(username="admin", password="admin")
# shouldn't change as user is already logged in
assert admin_user.num_of_logins == 1
client.user_session_manager.local_logout()
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 2
def test_local_login_failure(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert not client.user_session_manager.local_user_logged_in
def test_new_user_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_manager.add_user(username="jane.doe", password="12345")
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
def test_new_local_login_clears_previous_login(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
client.user_manager.add_user(username="jane.doe", password="12345")
new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "jane.doe"
assert new_session_id != current_session_id
def test_new_local_login_attempt_same_uses_persists(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
new_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
assert new_session_id == current_session_id
def test_remote_login_success(client_server_network):
# partial test for now until we get the terminal application in so that amn actual remote connection can be made
client, server, network = client_server_network
assert not server.user_session_manager.remote_sessions
remote_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
assert server.user_session_manager.validate_remote_session_uuid(remote_session_id)
server.user_session_manager.remote_logout(remote_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_remote_login_failure(client_server_network):
# partial test for now until we get the terminal application in so that amn actual remote connection can be made
client, server, network = client_server_network
assert not server.user_session_manager.remote_sessions
remote_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_new_user_remote_login_success(client_server_network):
client, server, network = client_server_network
server.user_manager.add_user(username="jane.doe", password="12345")
remote_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
assert server.user_session_manager.validate_remote_session_uuid(remote_session_id)
server.user_session_manager.remote_logout(remote_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_max_remote_sessions_same_user(client_server_network):
client, server, network = client_server_network
remote_session_ids = [
server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10")
for _ in range(server.user_session_manager.max_remote_sessions)
]
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_max_remote_sessions_different_users(client_server_network):
client, server, network = client_server_network
remote_session_ids = []
for i in range(server.user_session_manager.max_remote_sessions):
username = str(uuid4())
password = "12345"
server.user_manager.add_user(username=username, password=password)
remote_session_ids.append(
server.user_session_manager.remote_login(
username=username, password=password, remote_ip_address="192.168.1.10"
)
)
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_max_remote_sessions_limit_reached(client_server_network):
client, server, network = client_server_network
remote_session_ids = [
server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10")
for _ in range(server.user_session_manager.max_remote_sessions)
]
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
assert len(server.user_session_manager.remote_sessions) == server.user_session_manager.max_remote_sessions
fourth_attempt_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
assert not server.user_session_manager.validate_remote_session_uuid(fourth_attempt_session_id)
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_single_remote_logout_others_persist(client_server_network):
client, server, network = client_server_network
server.user_manager.add_user(username="jane.doe", password="12345")
server.user_manager.add_user(username="john.doe", password="12345")
admin_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
jane_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
john_session_id = server.user_session_manager.remote_login(
username="john.doe", password="12345", remote_ip_address="192.168.1.10"
)
server.user_session_manager.remote_logout(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert server.user_session_manager.validate_remote_session_uuid(john_session_id)
server.user_session_manager.remote_logout(jane_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert server.user_session_manager.validate_remote_session_uuid(john_session_id)
server.user_session_manager.remote_logout(john_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(john_session_id)