Merge remote-tracking branch 'origin/4.0.0a1-dev' into feature/2869-Marek
This commit is contained in:
@@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml"
|
||||
|
||||
TestApplications = ["DummyApplication", "BroadcastTestClient"]
|
||||
|
||||
@@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_default_fix_duration():
|
||||
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
|
||||
def test_default_fixing_duration():
|
||||
"""Test that software with no defined fixing duration in config uses the default fixing duration of 2."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 2
|
||||
assert database_client.config.fixing_duration == 2
|
||||
|
||||
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
|
||||
assert dns_client.fixing_duration == 2
|
||||
assert dns_client.config.fixing_duration == 2
|
||||
|
||||
|
||||
def test_fix_duration_set_from_config():
|
||||
"""Test to check that the fix duration set for applications and services works as intended."""
|
||||
def test_fixing_duration_set_from_config():
|
||||
"""Test to check that the fixing duration set for applications and services works as intended."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 3
|
||||
assert client_1.software_manager.software.get(service).config.fixing_duration == 3
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
# remove test applications from list
|
||||
@@ -55,27 +55,27 @@ def test_fix_duration_set_from_config():
|
||||
|
||||
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
|
||||
assert client_1.software_manager.software.get(application) is not None
|
||||
assert client_1.software_manager.software.get(application).fixing_duration == 1
|
||||
assert client_1.software_manager.software.get(application).config.fixing_duration == 1
|
||||
|
||||
|
||||
def test_fix_duration_for_one_item():
|
||||
"""Test that setting fix duration for one application does not affect other components."""
|
||||
def test_fixing_duration_for_one_item():
|
||||
"""Test that setting fixing duration for one application does not affect other components."""
|
||||
game = load_config(ONE_ITEM_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 2
|
||||
assert client_1.software_manager.software.get(service).config.fixing_duration == 2
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
# remove test applications from list
|
||||
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 2
|
||||
assert client_1.software_manager.software.get(applications).config.fixing_duration == 2
|
||||
|
||||
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 1
|
||||
assert database_client.config.fixing_duration == 1
|
||||
|
||||
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
|
||||
assert database_service.fixing_duration == 5
|
||||
assert database_service.config.fixing_duration == 5
|
||||
@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
@@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
|
||||
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Application.ConfigSchema):
|
||||
"""ConfigSchema for ExtendedApplication."""
|
||||
|
||||
type: str = "ExtendedApplication"
|
||||
target_url: Optional[str] = None
|
||||
|
||||
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
|
||||
|
||||
target_url: Optional[str] = None
|
||||
|
||||
domain_name_ip_address: Optional[IPv4Address] = None
|
||||
@@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
|
||||
kwargs["port"] = PORT_LOOKUP["HTTP"]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.target_url = self.config.target_url
|
||||
self.run()
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
|
||||
@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
@@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ExtendedService(Service, identifier="extendedservice"):
|
||||
class ExtendedService(Service, identifier="ExtendedService"):
|
||||
"""
|
||||
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
|
||||
|
||||
This class inherits from the `Service` class and provides methods to simulate a SQL database.
|
||||
"""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for ExtendedService."""
|
||||
|
||||
type: str = "ExtendedService"
|
||||
|
||||
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
|
||||
|
||||
password: Optional[str] = None
|
||||
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation):
|
||||
|
||||
# send a database query
|
||||
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
|
||||
browser.target_url = f"http://arcd.com/"
|
||||
browser.config.target_url = f"http://arcd.com/"
|
||||
browser.get_webpage()
|
||||
|
||||
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
|
||||
|
||||
@@ -183,7 +183,7 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Remove rule that allows HTTP traffic across the network
|
||||
@@ -216,7 +216,7 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
@@ -416,7 +416,7 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
@@ -476,7 +476,7 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com
|
||||
|
||||
assert browser.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
browser.config.target_url = "http://www.example.com"
|
||||
agent.reward_function.register_component(comp, 0.7)
|
||||
|
||||
# Check that before trying to fetch the webpage, the reward is 0.0
|
||||
|
||||
@@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class BroadcastTestService(Service):
|
||||
class BroadcastTestService(Service, identifier="BroadcastTestService"):
|
||||
"""A service for sending broadcast and unicast messages over a network."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for BroadcastTestService."""
|
||||
|
||||
type: str = "BroadcastTestService"
|
||||
|
||||
config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Set default service properties for broadcasting
|
||||
kwargs["name"] = "BroadcastService"
|
||||
@@ -46,6 +54,13 @@ class BroadcastTestService(Service):
|
||||
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for BroadcastTestClient."""
|
||||
|
||||
type: str = "BroadcastTestClient"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -495,6 +495,12 @@ def test_c2_suite_yaml():
|
||||
|
||||
computer_b: Computer = yaml_network.get_node_by_hostname("node_b")
|
||||
c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon")
|
||||
c2_beacon.configure(
|
||||
c2_server_ip_address=c2_beacon.config.c2_server_ip_address,
|
||||
keep_alive_frequency=c2_beacon.config.keep_alive_frequency,
|
||||
masquerade_port=c2_beacon.config.masquerade_port,
|
||||
masquerade_protocol=c2_beacon.config.masquerade_protocol,
|
||||
)
|
||||
|
||||
assert c2_server.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network):
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
@@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network):
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
@@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network):
|
||||
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
for i in range(db_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
|
||||
class _DatabaseListener(Service):
|
||||
class _DatabaseListener(Service, identifier="_DatabaseListener"):
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSchema for _DatabaseListener."""
|
||||
|
||||
type: str = "_DatabaseListener"
|
||||
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
|
||||
|
||||
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
|
||||
name: str = "DatabaseListener"
|
||||
protocol: str = PROTOCOL_LOOKUP["TCP"]
|
||||
port: int = PORT_LOOKUP["NONE"]
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://{web_server_ip}/"
|
||||
web_browser_app.config.target_url = f"http://{web_server_ip}/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
@@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
web_browser_app.config.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
|
||||
@@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer,
|
||||
# Install Web Browser on computer
|
||||
computer.software_manager.install(WebBrowser)
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
web_browser.config.target_url = "http://arcd.com/users/"
|
||||
web_browser.run()
|
||||
|
||||
# Install DNS Client service on computer
|
||||
@@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network):
|
||||
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
for i in range(database_service.fixing_duration + 1):
|
||||
for i in range(database_service.config.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
Reference in New Issue
Block a user