#2887 - Resolve conflicts from merge

This commit is contained in:
Charlie Crane
2025-01-23 09:17:27 +00:00
174 changed files with 7047 additions and 8412 deletions

View File

@@ -13,8 +13,8 @@ from primaite.simulator.system.services.database.database_service import Databas
from primaite.simulator.system.services.dns.dns_client import DNSClient
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fixing_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fixing_duration_one_item.yaml"
TestApplications = ["DummyApplication", "BroadcastTestClient"]
@@ -27,27 +27,27 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
return PrimaiteGame.from_config(cfg)
def test_default_fix_duration():
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
def test_default_fixing_duration():
"""Test that software with no defined fixing duration in config uses the default fixing duration of 2."""
game = load_config(TEST_CONFIG)
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 2
assert database_client.config.fixing_duration == 2
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
assert dns_client.fixing_duration == 2
assert dns_client.config.fixing_duration == 2
def test_fix_duration_set_from_config():
"""Test to check that the fix duration set for applications and services works as intended."""
def test_fixing_duration_set_from_config():
"""Test to check that the fixing duration set for applications and services works as intended."""
game = load_config(TEST_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
assert client_1.software_manager.software.get(service).config.fixing_duration == 3
# in config - applications take 1 timestep to fix
# remove test applications from list
@@ -55,27 +55,27 @@ def test_fix_duration_set_from_config():
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
assert client_1.software_manager.software.get(application).config.fixing_duration == 1
def test_fix_duration_for_one_item():
"""Test that setting fix duration for one application does not affect other components."""
def test_fixing_duration_for_one_item():
"""Test that setting fixing duration for one application does not affect other components."""
game = load_config(ONE_ITEM_CONFIG)
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
assert client_1.software_manager.software.get(service).config.fixing_duration == 2
# in config - applications take 1 timestep to fix
# remove test applications from list
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2
assert client_1.software_manager.software.get(applications).config.fixing_duration == 2
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
assert database_client.fixing_duration == 1
assert database_client.config.fixing_duration == 1
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
assert database_service.fixing_duration == 5
assert database_service.config.fixing_duration == 5

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
from typing import Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from primaite import getLogger
from primaite.interface.request import RequestResponse
@@ -31,6 +31,14 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
class ConfigSchema(Application.ConfigSchema):
"""ConfigSchema for ExtendedApplication."""
type: str = "ExtendedApplication"
target_url: Optional[str] = None
config: "ExtendedApplication.ConfigSchema" = Field(default_factory=lambda: ExtendedApplication.ConfigSchema())
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
@@ -50,6 +58,7 @@ class ExtendedApplication(Application, identifier="ExtendedApplication"):
kwargs["port"] = PORT_LOOKUP["HTTP"]
super().__init__(**kwargs)
self.target_url = self.config.target_url
self.run()
def _init_request_manager(self) -> RequestManager:

View File

@@ -3,6 +3,8 @@ from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from uuid import uuid4
from pydantic import Field
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
@@ -17,13 +19,20 @@ from primaite.utils.validation.port import PORT_LOOKUP
_LOGGER = getLogger(__name__)
class ExtendedService(Service, identifier="extendedservice"):
class ExtendedService(Service, identifier="ExtendedService"):
"""
A copy of DatabaseService that uses the extension framework instead of being part of PrimAITE.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for ExtendedService."""
type: str = "ExtendedService"
config: "ExtendedService.ConfigSchema" = Field(default_factory=lambda: ExtendedService.ConfigSchema())
password: Optional[str] = None
"""Password that needs to be provided by clients if they want to connect to the DatabaseService."""

View File

@@ -33,22 +33,22 @@ def test_application_cannot_perform_actions_unless_running(game_and_agent_fixtur
browser.close()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = ("node_application_scan", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
action = ("node_application_close", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
action = ("node_application_fix", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})
action = ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED

View File

@@ -46,23 +46,21 @@ def test_c2_beacon_default(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgen
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
action = (
"NODE_APPLICATION_INSTALL",
{"node_id": 1, "application_name": "C2Beacon"},
"node_application_install",
{"node_name": "server_1", "application_name": "C2Beacon"},
)
agent.store_action(action)
game.step()
assert agent.history[-1].response.status == "success"
action = (
"CONFIGURE_C2_BEACON",
"configure_c2_beacon",
{
"node_id": 1,
"config": {
"c2_server_ip_address": "10.0.1.2",
"keep_alive_frequency": 5,
"masquerade_protocol": "TCP",
"masquerade_port": "HTTP",
},
"node_name": "server_1",
"c2_server_ip_address": "10.0.1.2",
"keep_alive_frequency": 5,
"masquerade_protocol": "TCP",
"masquerade_port": "HTTP",
},
)
agent.store_action(action)
@@ -70,8 +68,8 @@ def test_c2_beacon_default(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgen
assert agent.history[-1].response.status == "success"
action = (
"NODE_APPLICATION_EXECUTE",
{"node_id": 1, "application_id": 0},
"node_application_execute",
{"node_name": "server_1", "application_name": "C2Beacon"},
)
agent.store_action(action)
game.step()
@@ -103,14 +101,12 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
# C2 Action 1: Installing the RansomwareScript & Database client via Terminal
action = (
"C2_SERVER_TERMINAL_COMMAND",
"c2_server_terminal_command",
{
"node_id": 0,
"node_name": "client_1",
"ip_address": None,
"account": {
"username": "admin",
"password": "admin",
},
"username": "admin",
"password": "admin",
"commands": [
["software_manager", "application", "install", "RansomwareScript"],
["software_manager", "application", "install", "DatabaseClient"],
@@ -122,10 +118,11 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
assert agent.history[-1].response.status == "success"
action = (
"C2_SERVER_RANSOMWARE_CONFIGURE",
"c2_server_ransomware_configure",
{
"node_id": 0,
"config": {"server_ip_address": "10.0.2.3", "payload": "ENCRYPT"},
"node_name": "client_1",
"server_ip_address": "10.0.2.3",
"payload": "ENCRYPT",
},
)
agent.store_action(action)
@@ -134,16 +131,16 @@ def test_c2_server_ransomware(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
# Stepping a few timesteps to allow for the RansowmareScript to finish installing.
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
game.step()
game.step()
action = (
"C2_SERVER_RANSOMWARE_LAUNCH",
"c2_server_ransomware_launch",
{
"node_id": 0,
"node_name": "client_1",
},
)
agent.store_action(action)
@@ -181,17 +178,15 @@ def test_c2_server_data_exfiltration(game_and_agent_fixture: Tuple[PrimaiteGame,
# C2 Action: Data exfiltrate.
action = (
"C2_SERVER_DATA_EXFILTRATE",
"c2_server_data_exfiltrate",
{
"node_id": 0,
"node_name": "client_1",
"target_file_name": "database.db",
"target_folder_name": "database",
"exfiltration_folder_name": "spoils",
"target_ip_address": "10.0.2.3",
"account": {
"username": "admin",
"password": "admin",
},
"username": "admin",
"password": "admin",
},
)
agent.store_action(action)

View File

@@ -4,7 +4,7 @@ from ipaddress import IPv4Address
import pytest
from pydantic import ValidationError
from primaite.game.agent.actions import (
from primaite.game.agent.actions.software import (
ConfigureDatabaseClientAction,
ConfigureDoSBotAction,
ConfigureRansomwareScriptAction,
@@ -27,7 +27,6 @@ class TestConfigureDatabaseAction:
def test_configure_ip_password(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
@@ -35,13 +34,11 @@ class TestConfigureDatabaseAction:
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
action = (
"CONFIGURE_DATABASE_CLIENT",
"configure_database_client",
{
"node_id": 0,
"config": {
"server_ip_address": "192.168.1.99",
"server_password": "admin123",
},
"node_name": "client_1",
"server_ip_address": "192.168.1.99",
"server_password": "admin123",
},
)
agent.store_action(action)
@@ -53,7 +50,6 @@ class TestConfigureDatabaseAction:
def test_configure_ip(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
@@ -61,12 +57,10 @@ class TestConfigureDatabaseAction:
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
action = (
"CONFIGURE_DATABASE_CLIENT",
"configure_database_client",
{
"node_id": 0,
"config": {
"server_ip_address": "192.168.1.99",
},
"node_name": "client_1",
"server_ip_address": "192.168.1.99",
},
)
agent.store_action(action)
@@ -78,7 +72,6 @@ class TestConfigureDatabaseAction:
def test_configure_password(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
@@ -87,12 +80,10 @@ class TestConfigureDatabaseAction:
old_ip = db_client.server_ip_address
action = (
"CONFIGURE_DATABASE_CLIENT",
"configure_database_client",
{
"node_id": 0,
"config": {
"server_password": "admin123",
},
"node_name": "client_1",
"server_password": "admin123",
},
)
agent.store_action(action)
@@ -120,9 +111,6 @@ class TestConfigureRansomwareScriptAction:
def test_configure_ip_password(self, game_and_agent, config):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
agent.action_manager
)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
@@ -134,8 +122,8 @@ class TestConfigureRansomwareScriptAction:
old_payload = ransomware_script.payload
action = (
"CONFIGURE_RANSOMWARE_SCRIPT",
{"node_id": 0, "config": config},
"configure_ransomware_script",
{"node_name": "client_1", **config},
)
agent.store_action(action)
game.step()
@@ -151,18 +139,15 @@ class TestConfigureRansomwareScriptAction:
def test_invalid_config(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
agent.action_manager
)
# make sure there is a database client on this node
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(RansomwareScript)
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
action = (
"CONFIGURE_RANSOMWARE_SCRIPT",
"configure_ransomware_script",
{
"node_id": 0,
"node_name": "client_1",
"config": {"server_password": "admin123", "bad_option": 70},
},
)
@@ -172,28 +157,25 @@ class TestConfigureRansomwareScriptAction:
class TestConfigureDoSBot:
def test_configure_DoSBot(self, game_and_agent):
def test_configure_dos_bot(self, game_and_agent):
game, agent = game_and_agent
agent: ControlledAgent
agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.install(DoSBot)
dos_bot: DoSBot = client_1.software_manager.software["DoSBot"]
action = (
"CONFIGURE_DOSBOT",
"configure_dos_bot",
{
"node_id": 0,
"config": {
"target_ip_address": "192.168.1.99",
"target_port": "POSTGRES_SERVER",
"payload": "HACC",
"repeat": False,
"port_scan_p_of_success": 0.875,
"dos_intensity": 0.75,
"max_sessions": 50,
},
"node_name": "client_1",
"target_ip_address": "192.168.1.99",
"target_port": "POSTGRES_SERVER",
"payload": "HACC",
"repeat": False,
"port_scan_p_of_success": 0.875,
"dos_intensity": 0.75,
"max_sessions": 50,
},
)
agent.store_action(action)
@@ -239,7 +221,7 @@ class TestConfigureYAML:
assert db_client.server_password == "correct_password"
assert db_client.connect()
def test_configure_ransomware_script(self):
def test_c2_server_ransomware_configure(self):
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
client_2 = env.game.simulation.network.get_node_by_hostname("client_2")
assert client_2.software_manager.software.get("RansomwareScript") is None

View File

@@ -33,8 +33,8 @@ def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is None
action = (
"NODE_FILE_CREATE",
{"node_id": 0, "folder_name": random_folder, "file_name": random_file},
"node_file_create",
{"node_name": "client_1", "folder_name": random_folder, "file_name": random_file},
)
agent.store_action(action)
game.step()
@@ -51,8 +51,8 @@ def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge
assert file.deleted is False
action = (
"NODE_FILE_DELETE",
{"node_id": 0, "folder_id": 0, "file_id": 0},
"node_file_delete",
{"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
)
agent.store_action(action)
game.step()
@@ -69,11 +69,11 @@ def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent
file.corrupt()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
action = (
"NODE_FILE_SCAN",
{"node_id": 0, "folder_id": 0, "file_id": 0},
"node_file_scan",
{"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
)
agent.store_action(action)
game.step()
@@ -93,8 +93,8 @@ def test_file_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FILE_REPAIR",
{"node_id": 0, "folder_id": 0, "file_id": 0},
"node_file_repair",
{"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
)
agent.store_action(action)
game.step()
@@ -113,8 +113,8 @@ def test_file_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAg
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FILE_RESTORE",
{"node_id": 0, "folder_id": 0, "file_id": 0},
"node_file_restore",
{"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
)
agent.store_action(action)
game.step()
@@ -132,8 +132,8 @@ def test_file_corrupt_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAg
assert file.health_status == FileSystemItemHealthStatus.GOOD
action = (
"NODE_FILE_CORRUPT",
{"node_id": 0, "folder_id": 0, "file_id": 0},
"node_file_corrupt",
{"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
)
agent.store_action(action)
game.step()
@@ -150,8 +150,8 @@ def test_file_access_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge
assert file.num_access == 0
action = (
"NODE_FILE_ACCESS",
{"node_id": 0, "folder_name": file.folder_name, "file_name": file.name},
"node_file_access",
{"node_name": "client_1", "folder_name": file.folder_name, "file_name": file.name},
)
agent.store_action(action)
game.step()

View File

@@ -32,9 +32,9 @@ def test_create_folder(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
assert client_1.file_system.get_folder(folder_name=random_folder) is None
action = (
"NODE_FOLDER_CREATE",
"node_folder_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": random_folder,
},
)
@@ -52,18 +52,18 @@ def test_folder_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAge
folder = client_1.file_system.get_folder(folder_name="downloads")
assert folder.health_status == FileSystemItemHealthStatus.GOOD
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
folder.corrupt()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
action = (
"NODE_FOLDER_SCAN",
"node_folder_scan",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads
},
)
agent.store_action(action)
@@ -87,10 +87,10 @@ def test_folder_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyA
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FOLDER_REPAIR",
"node_folder_repair",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads
},
)
agent.store_action(action)
@@ -111,10 +111,10 @@ def test_folder_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FOLDER_RESTORE",
"node_folder_restore",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads
},
)
agent.store_action(action)

View File

@@ -29,10 +29,10 @@ def test_nic_cannot_be_turned_off_if_not_on(game_and_agent_fixture: Tuple[Primai
assert nic.enabled is False
action = (
"HOST_NIC_DISABLE",
"host_nic_disable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -50,10 +50,10 @@ def test_nic_cannot_be_turned_on_if_already_on(game_and_agent_fixture: Tuple[Pri
assert nic.enabled
action = (
"HOST_NIC_ENABLE",
"host_nic_enable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -71,10 +71,10 @@ def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[Pr
assert nic.enabled
action = (
"HOST_NIC_DISABLE",
"host_nic_disable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -83,10 +83,10 @@ def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[Pr
assert nic.enabled is False
action = (
"HOST_NIC_ENABLE",
"host_nic_enable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)

View File

@@ -29,28 +29,28 @@ def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, Proxy
assert client_1.operating_state == NodeOperatingState.ON
# turn it off
action = ("NODE_SHUTDOWN", {"node_id": 0})
action = ("node_shutdown", {"node_name": "client_1"})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
# turn it on
action = ("NODE_STARTUP", {"node_id": 0})
action = ("node_startup", {"node_name": "client_1"})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.BOOTING
for i in range(client_1.start_up_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
@@ -65,7 +65,7 @@ def test_node_cannot_be_started_up_if_node_is_already_on(game_and_agent_fixture:
assert client_1.operating_state == NodeOperatingState.ON
# turn it on
action = ("NODE_STARTUP", {"node_id": 0})
action = ("node_startup", {"node_name": "client_1"})
agent.store_action(action)
game.step()
@@ -80,14 +80,14 @@ def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture:
client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
# turn it ff
action = ("NODE_SHUTDOWN", {"node_id": 0})
action = ("node_shutdown", {"node_name": "client_1"})
agent.store_action(action)
game.step()

View File

@@ -31,7 +31,7 @@ def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
dns_server.pause()
assert dns_server.operating_state == ServiceOperatingState.PAUSED
action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0})
action = ("node_service_start", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.PAUSED
@@ -40,7 +40,7 @@ def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0})
action = ("node_service_start", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
@@ -54,7 +54,7 @@ def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent])
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
dns_server = server_1.software_manager.software.get("DNSServer")
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
action = ("node_service_resume", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.RUNNING
@@ -63,7 +63,7 @@ def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent])
assert dns_server.operating_state == ServiceOperatingState.PAUSED
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
action = ("node_service_resume", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
@@ -80,27 +80,27 @@ def test_service_cannot_perform_actions_unless_running(game_and_agent_fixture: T
dns_server.stop()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_PAUSE", {"node_id": 1, "service_id": 0})
action = ("node_service_pause", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
action = ("node_service_resume", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_RESTART", {"node_id": 1, "service_id": 0})
action = ("node_service_restart", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
action = ("node_service_fix", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED

View File

@@ -36,9 +36,9 @@ def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
"node_session_remote_login",
{
"node_id": 0,
"node_name": "client_1",
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
@@ -68,9 +68,9 @@ def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame,
server_1_usm.add_user("user123", "password", is_admin=True)
action = (
"SSH_TO_REMOTE",
"node_session_remote_login",
{
"node_id": 0,
"node_name": "client_1",
"username": "user123",
"password": "wrong_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
@@ -100,12 +100,13 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame
server_1_um.add_user("user123", "password", is_admin=True)
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
"node_account_change_password",
{
"node_id": 1, # server_1
"node_name": "server_1", # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -126,9 +127,9 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
# Log in remotely
action = (
"SSH_TO_REMOTE",
"node_session_remote_login",
{
"node_id": 0,
"node_name": "client_1",
"username": "user123",
"password": "password",
"remote_ip": str(server_1.network_interface[1].ip_address),
@@ -139,12 +140,13 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
# Change password
action = (
"NODE_ACCOUNTS_CHANGE_PASSWORD",
"node_account_change_password",
{
"node_id": 1, # server_1
"node_name": "server_1", # server_1
"username": "user123",
"current_password": "password",
"new_password": "different_password",
"remote_ip": str(server_1.network_interface[1].ip_address),
},
)
agent.store_action(action)
@@ -152,9 +154,9 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
# Assert that the user cannot execute an action
action = (
"NODE_SEND_REMOTE_COMMAND",
"node_send_remote_command",
{
"node_id": 0,
"node_name": "client_1",
"remote_ip": str(server_1.network_interface[1].ip_address),
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
},

View File

@@ -32,11 +32,11 @@ def test_file_observation(simulation):
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # good initial
assert observation_state.get("health_status") == 0 # initially unset
file.corrupt()
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # scan file so this changes
assert observation_state.get("health_status") == 0 # still default unset value because no scan happened
file.scan()
file.apply_timestep(0) # apply time step
@@ -63,11 +63,11 @@ def test_folder_observation(simulation):
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("FILES") is not None
assert observation_state.get("health_status") == 1
assert observation_state.get("health_status") == 0 # initially unset
file.corrupt() # corrupt just the file
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # scan folder to change this
assert observation_state.get("health_status") == 0 # still unset as no scan occurred yet
folder.scan()
for i in range(folder.scan_duration + 1):

View File

@@ -191,7 +191,7 @@ def test_nic_monitored_traffic(simulation):
# send a database query
browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
browser.target_url = f"http://arcd.com/"
browser.config.target_url = f"http://arcd.com/"
browser.get_webpage()
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")

View File

@@ -13,7 +13,7 @@ DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yam
def env_with_ssh() -> PrimaiteGymEnv:
"""Build data manipulation environment with SSH port open on router."""
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
env.agent.flatten_obs = False
env.agent.config.agent_settings.flatten_obs = False
router: Router = env.game.simulation.network.get_node_by_hostname("router_1")
router.acl.add_rule(ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=3)
return env

View File

@@ -24,12 +24,12 @@ def test_rng_seed_set(create_env):
env.reset(seed=3)
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
env.reset(seed=3)
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
assert a == b
@@ -40,11 +40,11 @@ def test_rng_seed_unset(create_env):
env.reset()
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
env.reset()
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do_nothing"]
assert a != b

View File

@@ -15,7 +15,6 @@ def test_mask_contents_correct():
net = sim.network
mask = game.action_mask("defender")
agent = env.agent
node_list = agent.action_manager.node_names
action_map = agent.action_manager.action_map
# CHECK NIC ENABLE/DISABLE ACTIONS
@@ -23,10 +22,10 @@ def test_mask_contents_correct():
mask = game.action_mask("defender")
act_type, act_params = action
if act_type == "NODE_NIC_ENABLE":
node_name = node_list[act_params["node_id"]]
if act_type == "node_nic_enable":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
nic_obj = node_obj.network_interface[act_params["nic_num"]]
assert nic_obj.enabled
assert not mask[action_num]
nic_obj.disable()
@@ -34,10 +33,10 @@ def test_mask_contents_correct():
assert mask[action_num]
nic_obj.enable()
if act_type == "NODE_NIC_DISABLE":
node_name = node_list[act_params["node_id"]]
if act_type == "node_nic_disable":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
nic_obj = node_obj.network_interface[act_params["nic_num"]]
assert nic_obj.enabled
assert mask[action_num]
nic_obj.disable()
@@ -45,14 +44,14 @@ def test_mask_contents_correct():
assert not mask[action_num]
nic_obj.enable()
if act_type == "ROUTER_ACL_ADDRULE":
if act_type == "router_acl_add_rule":
assert mask[action_num]
if act_type == "ROUTER_ACL_REMOVERULE":
if act_type == "router_acl_remove_rule":
assert mask[action_num]
if act_type == "NODE_RESET":
node_name = node_list[act_params["node_id"]]
if act_type == "node_reset":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
@@ -61,8 +60,8 @@ def test_mask_contents_correct():
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_SHUTDOWN":
node_name = node_list[act_params["node_id"]]
if act_type == "node_shutdown":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
@@ -71,8 +70,8 @@ def test_mask_contents_correct():
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_OS_SCAN":
node_name = node_list[act_params["node_id"]]
if act_type == "node_os_scan":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
@@ -81,8 +80,8 @@ def test_mask_contents_correct():
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_STARTUP":
node_name = node_list[act_params["node_id"]]
if act_type == "node_startup":
node_name = act_params["node_name"]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert not mask[action_num]
@@ -91,15 +90,15 @@ def test_mask_contents_correct():
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "DONOTHING":
if act_type == "do_nothing":
assert mask[action_num]
if act_type == "NODE_SERVICE_DISABLE":
if act_type == "node_service_disable":
assert mask[action_num]
if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]:
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
if act_type in ["node_service_scan", "node_service_stop", "node_service_pause"]:
node_name = act_params["node_name"]
service_name = act_params["service_name"]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
@@ -109,9 +108,9 @@ def test_mask_contents_correct():
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_RESUME":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
if act_type == "node_service_resume":
node_name = act_params["node_name"]
service_name = act_params["service_name"]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
@@ -121,9 +120,9 @@ def test_mask_contents_correct():
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_START":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
if act_type == "node_service_start":
node_name = act_params["node_name"]
service_name = act_params["service_name"]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
@@ -133,9 +132,9 @@ def test_mask_contents_correct():
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_ENABLE":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
if act_type == "node_service_enable":
node_name = act_params["node_name"]
service_name = act_params["service_name"]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
@@ -145,12 +144,10 @@ def test_mask_contents_correct():
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]:
node_name = node_list[act_params["node_id"]]
folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"])
file_name = agent.action_manager.get_file_name_by_idx(
act_params["node_id"], act_params["folder_id"], act_params["file_id"]
)
if act_type in ["node_file_scan", "node_file_checkhash", "node_file_delete"]:
node_name = act_params["node_name"]
folder_name = act_params["folder_name"]
file_name = act_params["file_name"]
node_obj = net.get_node_by_hostname(node_name)
file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True)
assert not file_obj.deleted

View File

@@ -32,10 +32,10 @@ FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.
def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
"""Test that the do_nothingAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
@@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy
assert svc.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -88,7 +88,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
svc.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a patch action
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
action = ("node_service_fix", {"node_name": "server_1", "service_name": "DNSServer"})
agent.store_action(action)
game.step()
@@ -96,7 +96,7 @@ def test_node_service_fix_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
assert svc.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the service is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert svc.health_state_actual == SoftwareHealthState.GOOD
@@ -121,18 +121,18 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
# 2: Add a rule to block client 1 from reaching server 2 on router
action = (
"ROUTER_ACL_ADDRULE",
"router_acl_add_rule",
{
"target_router": "router",
"position": 4, # 4th rule
"permission": 2, # DENY
"source_ip_id": 3, # 10.0.1.2 (client_1)
"dest_ip_id": 6, # 10.0.2.3 (server_2)
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
"position": 4,
"permission": "DENY",
"src_ip": "10.0.1.2",
"src_wildcard": "NONE",
"src_port": "ALL",
"dst_ip": "10.0.2.3",
"dst_wildcard": "NONE",
"dst_port": "ALL",
"protocol_name": "icmp",
},
)
agent.store_action(action)
@@ -148,24 +148,26 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
# 4: Add a rule to block server_1 from reaching server_2 on router (this should not affect comms as they are on same subnet)
action = (
"ROUTER_ACL_ADDRULE",
"router_acl_add_rule",
{
"target_router": "router",
"position": 5, # 5th rule
"permission": 2, # DENY
"source_ip_id": 5, # 10.0.2.2 (server_1)
"dest_ip_id": 6, # 10.0.2.3 (server_2)
"dest_port_id": 1, # ALL
"source_port_id": 1, # ALL
"protocol_id": 1, # ALL
"source_wildcard_id": 0,
"dest_wildcard_id": 0,
"permission": "DENY", # DENY
"src_ip": "10.0.2.2", # 10.0.2.2 (server_1)
"src_wildcard": 0,
"src_port": "ALL", # ALL
"dst_ip": "10.0.2.3", # 10.0.2.3 (server_2)
"dst_wildcard": 0,
"dst_port": "ALL", # ALL
"protocol_name": "ALL", # ALL
},
)
agent.store_action(action)
print(agent.most_recent_action)
game.step()
print(agent.most_recent_action)
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
print(router.acl.show())
assert router.acl.num_rules == 6
assert server_1.ping("10.0.2.3") # Can ping server_2
@@ -181,12 +183,12 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Remove rule that allows HTTP traffic across the network
action = (
"ROUTER_ACL_REMOVERULE",
"router_acl_remove_rule",
{
"target_router": "router",
"position": 3, # 4th rule
@@ -214,15 +216,15 @@ def test_host_nic_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
action = (
"HOST_NIC_DISABLE",
"host_nic_disable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -250,10 +252,10 @@ def test_host_nic_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAg
# 2: Use action to enable nic
action = (
"HOST_NIC_ENABLE",
"host_nic_enable",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
"node_name": "client_1", # client_1
"nic_num": 1, # the only nic (eth-1)
},
)
agent.store_action(action)
@@ -273,15 +275,15 @@ def test_node_file_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAge
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file("downloads", "cat.png")
assert file.health_status == FileSystemItemHealthStatus.GOOD
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
# 2: perform a scan and make sure nothing has changed
action = (
"NODE_FILE_SCAN",
"node_file_scan",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads,
"file_id": 0, # cat.png
"node_name": "client_1", # client_1,
"folder_name": "downloads", # downloads,
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -314,11 +316,11 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
# 2: delete the file
action = (
"NODE_FILE_DELETE",
"node_file_delete",
{
"node_id": 0, # client_1
"folder_id": 0, # downloads
"file_id": 0, # cat.png
"node_name": "client_1", # client_1
"folder_name": "downloads", # downloads
"file_name": "cat.png", # cat.png
},
)
agent.store_action(action)
@@ -334,14 +336,15 @@ def test_node_file_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a file is created."""
game, agent = game_and_agent
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
client_1 = game.simulation.network.get_node_by_hostname("client_1")
action = (
"NODE_FILE_CREATE",
"node_file_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
"force": "False",
},
)
agent.store_action(action)
@@ -357,9 +360,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FILE_CREATE",
"node_file_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -370,9 +373,9 @@ def test_node_file_access(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
assert client_1.file_system.get_file(folder_name="test", file_name="file.txt").num_access == 0
action = (
"NODE_FILE_ACCESS",
"node_file_access",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
"file_name": "file.txt",
},
@@ -390,9 +393,9 @@ def test_node_folder_create(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
client_1 = game.simulation.network.get_node_by_hostname("client_1") #
action = (
"NODE_FOLDER_CREATE",
"node_folder_create",
{
"node_id": 0,
"node_name": "client_1",
"folder_name": "test",
},
)
@@ -413,15 +416,15 @@ def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteG
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
action = (
"NETWORK_PORT_DISABLE",
"network_port_disable",
{
"target_nodename": "router", # router
"port_id": 1, # port 1
"port_num": 1, # port 1
},
)
agent.store_action(action)
@@ -450,10 +453,10 @@ def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGa
# 2: Use action to enable port
action = (
"NETWORK_PORT_ENABLE",
"network_port_enable",
{
"target_nodename": "router", # router
"port_id": 1, # port 1
"port_num": 1, # port 1
},
)
agent.store_action(action)
@@ -473,14 +476,17 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com
assert browser.health_state_actual == SoftwareHealthState.GOOD
assert browser.health_state_visible == SoftwareHealthState.UNUSED
# 2: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = (
"node_application_scan",
{"node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -491,7 +497,10 @@ def test_node_application_scan_integration(game_and_agent: Tuple[PrimaiteGame, P
assert browser.health_state_visible == SoftwareHealthState.GOOD
# 4: Scan and check that the visible state is now correct
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
action = (
"node_application_scan",
{"node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.COMPROMISED
@@ -512,7 +521,10 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
browser.health_state_actual = SoftwareHealthState.COMPROMISED
# 2: Apply a fix action
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
action = (
"node_application_fix",
{"node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
@@ -520,7 +532,7 @@ def test_node_application_fix_integration(game_and_agent: Tuple[PrimaiteGame, Pr
assert browser.health_state_actual == SoftwareHealthState.FIXING
# 4: perform a few do-nothing steps and check that the application is now in the good state
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert browser.health_state_actual == SoftwareHealthState.GOOD
@@ -538,7 +550,10 @@ def test_node_application_close_integration(game_and_agent: Tuple[PrimaiteGame,
assert browser.operating_state == ApplicationOperatingState.RUNNING
# 2: Apply a close action
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
action = (
"node_application_close",
{"node_name": "client_1", "application_name": "WebBrowser"},
)
agent.store_action(action)
game.step()
@@ -549,7 +564,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
"""Test that the NodeApplicationInstallAction and NodeApplicationRemoveAction can form a request and that
it is accepted by the simulation.
When you initiate a install action, the Application will be installed and configured on the node.
When you initiate an install action, the Application will be installed and configured on the node.
The remove action will uninstall the application from the node."""
game, agent = game_and_agent
@@ -557,13 +572,19 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
assert client_1.software_manager.software.get("DoSBot") is None
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
action = (
"node_application_install",
{"node_name": "client_1", "application_name": "DoSBot"},
)
agent.store_action(action)
game.step()
assert client_1.software_manager.software.get("DoSBot") is not None
action = ("NODE_APPLICATION_REMOVE", {"node_id": 0, "application_name": "DoSBot"})
action = (
"node_application_remove",
{"node_name": "client_1", "application_name": "DoSBot"},
)
agent.store_action(action)
game.step()
@@ -656,9 +677,9 @@ def test_firewall_acl_add_remove_rule_integration():
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
assert firewall.external_outbound_acl.acl[1].dst_port is None
assert firewall.external_outbound_acl.acl[1].src_port is None
assert firewall.external_outbound_acl.acl[1].protocol is None
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
env.step(12) # Remove ACL rule from External Outbound
assert firewall.external_outbound_acl.num_rules == 1

View File

@@ -17,12 +17,7 @@ def test_file_observation():
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
file_system_requires_scan=True,
file_system_requires_scan=False,
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
# TODO:
# def test_file_num_access():
# ...

View File

@@ -29,16 +29,16 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
browser.config.target_url = "http://www.example.com"
agent.reward_function.register_component(comp, 0.7)
# Check that before trying to fetch the webpage, the reward is 0.0
agent.store_action(("DONOTHING", {}))
agent.store_action(("do_nothing", {}))
game.step()
assert agent.reward_function.current_reward == 0.0
# Check that successfully fetching the webpage yields a reward of 0.7
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
agent.store_action(("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}))
game.step()
assert agent.reward_function.current_reward == 0.7
@@ -50,7 +50,7 @@ def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, Controlle
src_port=PORT_LOOKUP["HTTP"],
dst_port=PORT_LOOKUP["HTTP"],
)
agent.store_action(("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}))
agent.store_action(("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}))
game.step()
assert agent.reward_function.current_reward == -0.7
@@ -83,7 +83,7 @@ def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
response = game.simulation.apply_request(request)
state = game.get_sim_state()
ahi = AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
timestep=0, action="node_application_execute", parameters={}, request=request, response=response
)
reward_value = comp.calculate(state, last_action_response=ahi)
assert reward_value == 1.0
@@ -94,7 +94,7 @@ def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
response = game.simulation.apply_request(request)
state = game.get_sim_state()
ahi = AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=request, response=response
timestep=0, action="node_application_execute", parameters={}, request=request, response=response
)
reward_value = comp.calculate(
state,
@@ -154,13 +154,13 @@ def test_action_penalty():
# Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
Penalty = ActionPenalty(config=schema)
# Assert that penalty is applied if action isn't DONOTHING
# Assert that penalty is applied if action isn't do_nothing
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
timestep=0,
action="NODE_APPLICATION_EXECUTE",
parameters={"node_id": 0, "application_id": 1},
action="node_application_execute",
parameters={"node_name": "client", "application_name": "WebBrowser"},
request=["execute"],
response=RequestResponse.from_bool(True),
),
@@ -168,12 +168,12 @@ def test_action_penalty():
assert reward_value == -0.75
# Assert that no penalty applied for a DONOTHING action
# Assert that no penalty applied for a do_nothing action
reward_value = Penalty.calculate(
state={},
last_action_response=AgentHistoryItem(
timestep=0,
action="DONOTHING",
action="do_nothing",
parameters={},
request=["do_nothing"],
response=RequestResponse.from_bool(True),
@@ -192,12 +192,12 @@ def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]
agent.reward_function.register_component(comp, 1.0)
action = ("DONOTHING", {})
action = ("do_nothing", {})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == 0.125
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
action = ("node_file_scan", {"node_name": "client", "folder_name": "downloads", "file_name": "document.pdf"})
agent.store_action(action)
game.step()
assert agent.reward_function.current_reward == -0.75

View File

@@ -3,6 +3,7 @@ from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Tuple
import pytest
from pydantic import Field
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -14,9 +15,16 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
from primaite.utils.validation.port import PORT_LOOKUP
class BroadcastTestService(Service):
class BroadcastTestService(Service, identifier="BroadcastTestService"):
"""A service for sending broadcast and unicast messages over a network."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestService."""
type: str = "BroadcastTestService"
config: "BroadcastTestService.ConfigSchema" = Field(default_factory=lambda: BroadcastTestService.ConfigSchema())
def __init__(self, **kwargs):
# Set default service properties for broadcasting
kwargs["name"] = "BroadcastService"
@@ -46,6 +54,13 @@ class BroadcastTestService(Service):
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
"""A client application to receive broadcast and unicast messages."""
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for BroadcastTestClient."""
type: str = "BroadcastTestClient"
config: ConfigSchema = Field(default_factory=lambda: BroadcastTestClient.ConfigSchema())
payloads_received: List = []
def __init__(self, **kwargs):

View File

@@ -495,6 +495,12 @@ def test_c2_suite_yaml():
computer_b: Computer = yaml_network.get_node_by_hostname("node_b")
c2_beacon: C2Beacon = computer_b.software_manager.software.get("C2Beacon")
c2_beacon.configure(
c2_server_ip_address=c2_beacon.config.c2_server_ip_address,
keep_alive_frequency=c2_beacon.config.keep_alive_frequency,
masquerade_port=c2_beacon.config.masquerade_port,
masquerade_protocol=c2_beacon.config.masquerade_protocol,
)
assert c2_server.operating_state == ApplicationOperatingState.RUNNING

View File

@@ -163,7 +163,7 @@ def test_restore_backup_without_updating_scan(uc2_network):
db_service.db_file.corrupt() # corrupt the db
assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet
db_service.db_file.scan() # scan the db file
@@ -190,7 +190,7 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network):
db_service.db_file.corrupt() # corrupt the db
assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT # db file is actually corrupt
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # not scanned yet
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.NONE # not scanned yet
db_service.db_file.scan() # scan the db file
@@ -232,7 +232,7 @@ def test_database_service_fix(uc2_network):
assert db_service.health_state_actual == SoftwareHealthState.FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
@@ -266,7 +266,7 @@ def test_database_cannot_be_queried_while_fixing(uc2_network):
assert db_connection.query(sql="SELECT") is False
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD
@@ -308,7 +308,7 @@ def test_database_can_create_connection_while_fixing(uc2_network):
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
# apply timestep until the fix is applied
for i in range(db_service.fixing_duration + 1):
for i in range(db_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert db_service.health_state_actual == SoftwareHealthState.GOOD

View File

@@ -14,7 +14,14 @@ from primaite.utils.validation.port import PORT_LOOKUP
from tests import TEST_ASSETS_ROOT
class _DatabaseListener(Service):
class _DatabaseListener(Service, identifier="_DatabaseListener"):
class ConfigSchema(Service.ConfigSchema):
"""ConfigSchema for _DatabaseListener."""
type: str = "_DatabaseListener"
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
name: str = "DatabaseListener"
protocol: str = PROTOCOL_LOOKUP["TCP"]
port: int = PORT_LOOKUP["NONE"]

View File

@@ -51,7 +51,7 @@ def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_ser
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -66,7 +66,7 @@ def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_serv
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://{web_server_ip}/"
web_browser_app.config.target_url = f"http://{web_server_ip}/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -81,7 +81,7 @@ def test_web_page_request_from_shut_down_server(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.network_interfaces.get(next(iter(server.network_interfaces))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
@@ -108,7 +108,7 @@ def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_browser_app.target_url = f"http://arcd.com/"
web_browser_app.config.target_url = f"http://arcd.com/"
assert web_browser_app.get_webpage() is True
# latest response should have status code 200

View File

@@ -74,7 +74,7 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer,
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser.config.target_url = "http://arcd.com/users/"
web_browser.run()
# Install DNS Client service on computer
@@ -131,7 +131,7 @@ def test_database_fix_disrupts_web_client(uc2_network):
assert web_browser.get_webpage() is False
for i in range(database_service.fixing_duration + 1):
for i in range(database_service.config.fixing_duration + 1):
uc2_network.apply_timestep(i)
assert database_service.health_state_actual == SoftwareHealthState.GOOD