Merge branch 'dev' into bugfix/2676_NMNE_var_access
This commit is contained in:
@@ -9,9 +9,10 @@ from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
@@ -85,7 +86,7 @@ def test_node_software_install():
|
||||
assert client_2.software_manager.software.get(software.__name__) is not None
|
||||
|
||||
# check that applications have been installed on client 1
|
||||
for applications in APPLICATION_TYPES_MAPPING:
|
||||
for applications in Application._application_registry:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
|
||||
# check that services have been installed on client 1
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import copy
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
|
||||
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_default_fix_duration():
|
||||
"""Test that software with no defined fix duration in config uses the default fix duration of 2."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_2: Computer = game.simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
database_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 2
|
||||
|
||||
dns_client: DNSClient = client_2.software_manager.software.get("DNSClient")
|
||||
assert dns_client.fixing_duration == 2
|
||||
|
||||
|
||||
def test_fix_duration_set_from_config():
|
||||
"""Test to check that the fix duration set for applications and services works as intended."""
|
||||
game = load_config(TEST_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
for service in SERVICE_TYPES_MAPPING:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 3
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
for applications in APPLICATION_TYPES_MAPPING:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 1
|
||||
|
||||
|
||||
def test_fix_duration_for_one_item():
|
||||
"""Test that setting fix duration for one application does not affect other components."""
|
||||
game = load_config(ONE_ITEM_CONFIG)
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
# in config - services take 3 timesteps to fix
|
||||
services = copy.copy(SERVICE_TYPES_MAPPING)
|
||||
services.pop("DatabaseService")
|
||||
for service in services:
|
||||
assert client_1.software_manager.software.get(service) is not None
|
||||
assert client_1.software_manager.software.get(service).fixing_duration == 2
|
||||
|
||||
# in config - applications take 1 timestep to fix
|
||||
applications = copy.copy(APPLICATION_TYPES_MAPPING)
|
||||
applications.pop("DatabaseClient")
|
||||
for applications in applications:
|
||||
assert client_1.software_manager.software.get(applications) is not None
|
||||
assert client_1.software_manager.software.get(applications).fixing_duration == 2
|
||||
|
||||
database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert database_client.fixing_duration == 1
|
||||
|
||||
database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService")
|
||||
assert database_service.fixing_duration == 5
|
||||
1
tests/integration_tests/game_layer/actions/__init__.py
Normal file
1
tests/integration_tests/game_layer/actions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
@@ -0,0 +1,292 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
ConfigureDatabaseClientAction,
|
||||
ConfigureDoSBotAction,
|
||||
ConfigureRansomwareScriptAction,
|
||||
)
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import ControlledAgent
|
||||
|
||||
APP_CONFIG_YAML = TEST_ASSETS_ROOT / "configs/install_and_configure_apps.yaml"
|
||||
|
||||
|
||||
class TestConfigureDatabaseAction:
|
||||
def test_configure_ip_password(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_ip_address": "192.168.1.99",
|
||||
"server_password": "admin123",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
|
||||
assert db_client.server_password == "admin123"
|
||||
|
||||
def test_configure_ip(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_ip_address": "192.168.1.99",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == IPv4Address("192.168.1.99")
|
||||
assert db_client.server_password is None
|
||||
|
||||
def test_configure_password(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"]
|
||||
old_ip = db_client.server_ip_address
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DATABASE_CLIENT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"server_password": "admin123",
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert db_client.server_ip_address == old_ip
|
||||
assert db_client.server_password is "admin123"
|
||||
|
||||
|
||||
class TestConfigureRansomwareScriptAction:
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{},
|
||||
{"server_ip_address": "181.181.181.181"},
|
||||
{"server_password": "admin123"},
|
||||
{"payload": "ENCRYPT"},
|
||||
{
|
||||
"server_ip_address": "181.181.181.181",
|
||||
"server_password": "admin123",
|
||||
"payload": "ENCRYPT",
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_configure_ip_password(self, game_and_agent, config):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
|
||||
agent.action_manager
|
||||
)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(RansomwareScript)
|
||||
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
|
||||
|
||||
old_ip = ransomware_script.server_ip_address
|
||||
old_pw = ransomware_script.server_password
|
||||
old_payload = ransomware_script.payload
|
||||
|
||||
action = (
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT",
|
||||
{"node_id": 0, "config": config},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
expected_ip = old_ip if "server_ip_address" not in config else IPv4Address(config["server_ip_address"])
|
||||
expected_pw = old_pw if "server_password" not in config else config["server_password"]
|
||||
expected_payload = old_payload if "payload" not in config else config["payload"]
|
||||
|
||||
assert ransomware_script.server_ip_address == expected_ip
|
||||
assert ransomware_script.server_password == expected_pw
|
||||
assert ransomware_script.payload == expected_payload
|
||||
|
||||
def test_invalid_config(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_RANSOMWARE_SCRIPT"] = ConfigureRansomwareScriptAction(
|
||||
agent.action_manager
|
||||
)
|
||||
|
||||
# make sure there is a database client on this node
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(RansomwareScript)
|
||||
ransomware_script: RansomwareScript = client_1.software_manager.software["RansomwareScript"]
|
||||
action = (
|
||||
"CONFIGURE_RANSOMWARE_SCRIPT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {"server_password": "admin123", "bad_option": 70},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
with pytest.raises(ValidationError):
|
||||
game.step()
|
||||
|
||||
|
||||
class TestConfigureDoSBot:
|
||||
def test_configure_DoSBot(self, game_and_agent):
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
agent.action_manager.actions["CONFIGURE_DOSBOT"] = ConfigureDoSBotAction(agent.action_manager)
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(DoSBot)
|
||||
dos_bot: DoSBot = client_1.software_manager.software["DoSBot"]
|
||||
|
||||
action = (
|
||||
"CONFIGURE_DOSBOT",
|
||||
{
|
||||
"node_id": 0,
|
||||
"config": {
|
||||
"target_ip_address": "192.168.1.99",
|
||||
"target_port": "POSTGRES_SERVER",
|
||||
"payload": "HACC",
|
||||
"repeat": False,
|
||||
"port_scan_p_of_success": 0.875,
|
||||
"dos_intensity": 0.75,
|
||||
"max_sessions": 50,
|
||||
},
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert dos_bot.target_ip_address == IPv4Address("192.168.1.99")
|
||||
assert dos_bot.target_port == Port.POSTGRES_SERVER
|
||||
assert dos_bot.payload == "HACC"
|
||||
assert not dos_bot.repeat
|
||||
assert dos_bot.port_scan_p_of_success == 0.875
|
||||
assert dos_bot.dos_intensity == 0.75
|
||||
assert dos_bot.max_sessions == 50
|
||||
|
||||
|
||||
class TestConfigureYAML:
|
||||
def test_configure_db_client(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
|
||||
# make sure there's no db client on the node yet
|
||||
client_1 = env.game.simulation.network.get_node_by_hostname("client_1")
|
||||
assert client_1.software_manager.software.get("DatabaseClient") is None
|
||||
|
||||
# take the install action, check that the db gets installed, step to get it to finish installing
|
||||
env.step(1)
|
||||
db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
|
||||
assert isinstance(db_client, DatabaseClient)
|
||||
assert db_client.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# configure the ip address and check that it changes, but password doesn't change
|
||||
assert db_client.server_ip_address is None
|
||||
assert db_client.server_password is None
|
||||
env.step(4)
|
||||
assert db_client.server_ip_address == IPv4Address("10.0.0.5")
|
||||
assert db_client.server_password is None
|
||||
|
||||
# configure the password and check that it changes, make sure this lets us connect to the db
|
||||
assert not db_client.connect()
|
||||
env.step(5)
|
||||
assert db_client.server_password == "correct_password"
|
||||
assert db_client.connect()
|
||||
|
||||
def test_configure_ransomware_script(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
client_2 = env.game.simulation.network.get_node_by_hostname("client_2")
|
||||
assert client_2.software_manager.software.get("RansomwareScript") is None
|
||||
|
||||
# install ransomware script
|
||||
env.step(2)
|
||||
ransom = client_2.software_manager.software.get("RansomwareScript")
|
||||
assert isinstance(ransom, RansomwareScript)
|
||||
assert ransom.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# make sure it's not working yet because it's not configured and there's no db client
|
||||
assert not ransom.attack()
|
||||
env.step(8) # install db client on the same node
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0) # let it finish installing
|
||||
assert not ransom.attack()
|
||||
|
||||
# finally, configure the ransomware script with ip and password
|
||||
env.step(6)
|
||||
assert ransom.attack()
|
||||
|
||||
db_server = env.game.simulation.network.get_node_by_hostname("server_1")
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
def test_configure_dos_bot(self):
|
||||
env = PrimaiteGymEnv(env_config=APP_CONFIG_YAML)
|
||||
client_3 = env.game.simulation.network.get_node_by_hostname("client_3")
|
||||
assert client_3.software_manager.software.get("DoSBot") is None
|
||||
|
||||
# install DoSBot
|
||||
env.step(3)
|
||||
bot = client_3.software_manager.software.get("DoSBot")
|
||||
assert isinstance(bot, DoSBot)
|
||||
assert bot.operating_state == ApplicationOperatingState.INSTALLING
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# make sure dos bot doesn't work before being configured
|
||||
assert not bot.run()
|
||||
env.step(7)
|
||||
assert bot.run()
|
||||
@@ -557,7 +557,7 @@ def test_node_application_install_and_uninstall_integration(game_and_agent: Tupl
|
||||
|
||||
assert client_1.software_manager.software.get("DoSBot") is None
|
||||
|
||||
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot", "ip_address": "192.168.1.14"})
|
||||
action = ("NODE_APPLICATION_INSTALL", {"node_id": 0, "application_name": "DoSBot"})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.agent.rewards import ActionPenalty, GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
@@ -119,3 +121,77 @@ def test_shared_reward():
|
||||
g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward
|
||||
blue_reward = env.game.agents["defender"].reward_function.current_reward
|
||||
assert blue_reward == g1_reward + g2_reward
|
||||
|
||||
|
||||
def test_action_penalty_loads_from_config():
|
||||
"""Test to ensure that action penalty is correctly loaded from config into PrimaiteGymEnv"""
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/action_penalty.yaml"
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
defender = env.game.agents["defender"]
|
||||
act_penalty_obj = None
|
||||
for comp in defender.reward_function.reward_components:
|
||||
if isinstance(comp[0], ActionPenalty):
|
||||
act_penalty_obj = comp[0]
|
||||
if act_penalty_obj is None:
|
||||
pytest.fail("Action penalty reward component was not added to the agent from config.")
|
||||
assert act_penalty_obj.action_penalty == -0.75
|
||||
assert act_penalty_obj.do_nothing_penalty == 0.125
|
||||
|
||||
|
||||
def test_action_penalty():
|
||||
"""Test that the action penalty is correctly applied when agent performs any action"""
|
||||
|
||||
# Create an ActionPenalty Reward
|
||||
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
# Assert that penalty is applied if action isn't DONOTHING
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="NODE_APPLICATION_EXECUTE",
|
||||
parameters={"node_id": 0, "application_id": 1},
|
||||
request=["execute"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == -0.75
|
||||
|
||||
# Assert that no penalty applied for a DONOTHING action
|
||||
reward_value = Penalty.calculate(
|
||||
state={},
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0,
|
||||
action="DONOTHING",
|
||||
parameters={},
|
||||
request=["do_nothing"],
|
||||
response=RequestResponse.from_bool(True),
|
||||
),
|
||||
)
|
||||
|
||||
assert reward_value == 0.125
|
||||
|
||||
|
||||
def test_action_penalty_e2e(game_and_agent):
|
||||
"""Test that we get the right reward for doing actions to fetch a website."""
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
|
||||
agent.reward_function.register_component(comp, 1.0)
|
||||
|
||||
action = ("DONOTHING", {})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == 0.125
|
||||
|
||||
action = ("NODE_FILE_SCAN", {"node_id": 0, "folder_id": 0, "file_id": 0})
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.reward_function.current_reward == -0.75
|
||||
|
||||
@@ -41,7 +41,7 @@ class BroadcastService(Service):
|
||||
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
|
||||
|
||||
|
||||
class BroadcastClient(Application):
|
||||
class BroadcastClient(Application, identifier="BroadcastClient"):
|
||||
"""A client application to receive broadcast and unicast messages."""
|
||||
|
||||
payloads_received: List = []
|
||||
|
||||
@@ -3,9 +3,9 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.networks import multi_lan_internet_network_example
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from src.primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
|
||||
def test_all_with_configured_dns_server_ip_can_resolve_url():
|
||||
|
||||
@@ -21,7 +21,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
|
||||
computer.power_on()
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app = computer.software_manager.software.get("TestApplication")
|
||||
app = computer.software_manager.software.get("DummyApplication")
|
||||
app.run()
|
||||
|
||||
return app, computer
|
||||
@@ -39,7 +39,7 @@ def test_application_on_offline_node(application_class):
|
||||
)
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app: Application = computer.software_manager.software.get("TestApplication")
|
||||
app: Application = computer.software_manager.software.get("DummyApplication")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -213,6 +214,110 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network):
|
||||
assert db_service.db_file.visible_health_status == FileSystemItemHealthStatus.GOOD # now looks good
|
||||
|
||||
|
||||
def test_database_service_fix(uc2_network):
|
||||
"""Test that the software fix applies to database service."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
|
||||
def test_database_cannot_be_queried_while_fixing(uc2_network):
|
||||
"""Tests that the database service cannot be queried if the service is being fixed."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# fails to query because database is in FIXING state
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
|
||||
def test_database_can_create_connection_while_fixing(uc2_network):
|
||||
"""Tests that connections cannot be created while the database is being fixed."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
client_2: Server = uc2_network.get_node_by_hostname("client_2")
|
||||
db_client: DatabaseClient = client_2.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
# delete database locally
|
||||
db_service.file_system.delete_file(folder_name="database", file_name="database.db")
|
||||
|
||||
# db file is gone, reduced to atoms
|
||||
assert db_service.db_file is None
|
||||
|
||||
db_service.fix() # fix the database service
|
||||
assert db_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
# fails to query because database is in FIXING state
|
||||
assert db_connection.query(sql="SELECT") is False
|
||||
|
||||
# should be able to create a new connection
|
||||
new_db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
assert new_db_connection is not None
|
||||
assert new_db_connection.query(sql="SELECT") is False # still should fail to query because FIXING
|
||||
|
||||
# apply timestep until the fix is applied
|
||||
for i in range(db_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert db_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
assert db_service.db_file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
assert new_db_connection.query(sql="SELECT")
|
||||
|
||||
|
||||
def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 404 when db server is offline."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
|
||||
@@ -16,6 +16,7 @@ from primaite.simulator.system.services.database.database_service import Databas
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -110,6 +111,29 @@ def test_web_client_requests_users(web_client_web_server_database):
|
||||
assert web_browser.get_webpage()
|
||||
|
||||
|
||||
def test_database_fix_disrupts_web_client(uc2_network):
|
||||
"""Tests that the database service being in fixed state disrupts the web client."""
|
||||
computer: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
database_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
# fix the database service
|
||||
database_service.fix()
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
for i in range(database_service.fixing_duration + 1):
|
||||
uc2_network.apply_timestep(i)
|
||||
|
||||
assert database_service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
assert web_browser.get_webpage()
|
||||
|
||||
|
||||
class TestWebBrowserHistory:
|
||||
def test_populating_history(self, web_client_web_server_database):
|
||||
network, computer, _, _ = web_client_web_server_database
|
||||
|
||||
@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from tests.conftest import TestApplication, TestService
|
||||
from tests.conftest import DummyApplication, TestService
|
||||
|
||||
|
||||
def test_successful_node_file_system_creation_request(example_network):
|
||||
@@ -47,14 +47,14 @@ def test_successful_application_requests(example_network):
|
||||
net = example_network
|
||||
|
||||
client_1 = net.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.install(TestApplication)
|
||||
client_1.software_manager.software.get("TestApplication").run()
|
||||
client_1.software_manager.install(DummyApplication)
|
||||
client_1.software_manager.software.get("DummyApplication").run()
|
||||
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"])
|
||||
resp_1 = net.apply_request(["node", "client_1", "application", "DummyApplication", "scan"])
|
||||
assert resp_1 == RequestResponse(status="success", data={})
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "fix"])
|
||||
resp_2 = net.apply_request(["node", "client_1", "application", "DummyApplication", "fix"])
|
||||
assert resp_2 == RequestResponse(status="success", data={})
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"])
|
||||
resp_3 = net.apply_request(["node", "client_1", "application", "DummyApplication", "compromise"])
|
||||
assert resp_3 == RequestResponse(status="success", data={})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user