Merge remote-tracking branch 'origin/4.0.0a1-dev' into feature/2869-Marek

This commit is contained in:
Marek Wolan
2025-01-20 10:39:20 +00:00
46 changed files with 559 additions and 252 deletions

View File

@@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas
from primaite.simulator.system.services.dns.dns_client import DNSClient
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml"
TestApplications = ["DummyApplication", "BroadcastTestClient"]
@@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
return PrimaiteGame.from_config(cfg)
def test_default_fix_duration():
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
def test_default_fixing_duration():
"""Test that software with no defined fixing duration in config uses the default fixing duration of 2."""
game = load_config(TEST_CONFIG)
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 2
assert database_client.config.fixing_duration == 2
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
assert dns_client.fixing_duration == 2
assert dns_client.config.fixing_duration == 2
def test_fix_duration_set_from_config():
"""Test to check that the fix duration set for applications and services works as intended."""
def test_fixing_duration_set_from_config():
"""Test to check that the fixing duration set for applications and services works as intended."""
game = load_config(TEST_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
assert client_1.software_manager.software.get(service).config.fixing_duration == 3
# in config - applications take 1 timestep to fix
# remove test applications from list
@@ -55,27 +55,27 @@ def test_fix_duration_set_from_config():
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
assert client_1.software_manager.software.get(application).config.fixing_duration == 1
def test_fix_duration_for_one_item():
"""Test that setting fix duration for one application does not affect other components."""
def test_fixing_duration_for_one_item():
"""Test that setting fixing duration for one application does not affect other components."""
game = load_config(ONE_ITEM_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
assert client_1.software_manager.software.get(service).config.fixing_duration == 2
# in config - applications take 1 timestep to fix
# remove test applications from list
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2
assert client_1.software_manager.software.get(applications).config.fixing_duration == 2
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 1
assert database_client.config.fixing_duration == 1
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
assert database_service.fixing_duration == 5
assert database_service.config.fixing_duration == 5

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from primaite import getLogger
from primaite.interface.request import RequestResponse
@@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for ExtendedApplication."""
type: str = "ExtendedApplication"
target_url: Optional[str] = None
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
@@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
kwargs["port"] = PORT_LOOKUP["HTTP"]
super().__init__(**kwargs)
self.target_url = self.config.target_url
self.run()
def _init_request_manager(self) -> RequestManager:

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
@@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class ExtendedService(Service, identifier="extendedservice"):
class ExtendedService(Service, identifier="ExtendedService"):
"""
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ExtendedService."""
type: str = "ExtendedService"
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""

View File

@@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation):
# send a database query
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
browser.target_url = f"http://arcd.com/"
browser.config.target_url = f"http://arcd.com/"
browser.get_webpage()
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

View File

@@ -183,7 +183,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Remove rule that allows HTTP traffic across the network
@@ -216,7 +216,7 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
@@ -416,7 +416,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
@@ -476,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com
assert browser.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -29,7 +29,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
agent.reward_function.register_component(comp, 0.7)
# Check that before trying to fetch the webpage, the reward is 0.0

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Tuple
import pytest
from pydantic import Field
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class BroadcastTestService(Service):
class BroadcastTestService(Service, identifier="BroadcastTestService"):
"""A service for sending broadcast and unicast messages over a network."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestService."""
type: str = "BroadcastTestService"
config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())
def __init__(self, **kwargs):
# Set default service properties for broadcasting
kwargs["name"] = "BroadcastService"
@@ -46,6 +54,13 @@ class BroadcastTestService(Service):
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
"""A client application to receive broadcast and unicast messages."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestClient."""
type: str = "BroadcastTestClient"
config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())
payloads_received: List = []
def __init__(self, **kwargs):

View File

@@ -495,6 +495,12 @@ def test_c2_suite_yaml():
computer_b: Computer = yaml_network.get_node_by_hostname("node_b")
c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon")
c2_beacon.configure(
c2_server_ip_address=c2_beacon.config.c2_server_ip_address,
keep_alive_frequency=c2_beacon.config.keep_alive_frequency,
masquerade_port=c2_beacon.config.masquerade_port,
masquerade_protocol=c2_beacon.config.masquerade_protocol,
)
assert c2_server.operating_state == ApplicationOperatingState.RUNNING

View File

@@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network):
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
@@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network):
assert db_connection.query(sql="SELECT") is False
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD
@@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network):
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
from tests import TEST_ASSETS_ROOT
class _DatabaseListener(Service):
class _DatabaseListener(Service, identifier="_DatabaseListener"):
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for _DatabaseListener."""
type: str = "_DatabaseListener"
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
name: str = "DatabaseListener"
protocol: str = PROTOCOL_LOOKUP["TCP"]
port: int = PORT_LOOKUP["NONE"]

View File

@@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://{web_server_ip}/"
web_browser_app.config.target_url = f"http://{web_server_ip}/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.get_webpage() is True
# latest response should have status code 200

View File

@@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer,
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser.config.target_url = "http://arcd.com/users/"
web_browser.run()
# Install DNS Client service on computer
@@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network):
assert web_browser.get_webpage() is False
for i in range(database_service.fixing_duration + 1):
for i in range(database_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert database_service.health_state_actual == SoftwareHealthState.GOOD