#2887 - Resolve conflicts from merge
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from primaite.game.agent.actions import (
|
||||
ActionManager,
|
||||
DoNothingAction,
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.actions.manager import DoNothingAction
|
||||
from primaite.game.agent.actions.service import (
|
||||
NodeServiceDisableAction,
|
||||
NodeServiceEnableAction,
|
||||
NodeServicePauseAction,
|
||||
@@ -18,13 +19,8 @@ from primaite.game.agent.actions import (
|
||||
|
||||
|
||||
def test_do_nothing_action_form_request():
|
||||
"""Test that the DoNothingAction can form a request and that it is correct."""
|
||||
manager = Mock()
|
||||
|
||||
action = DoNothingAction(manager=manager)
|
||||
|
||||
request = action.form_request()
|
||||
|
||||
"""Test that the do_nothingAction can form a request and that it is correct."""
|
||||
request = DoNothingAction.form_request(DoNothingAction.ConfigSchema())
|
||||
assert request == ["do_nothing"]
|
||||
|
||||
|
||||
@@ -42,7 +38,7 @@ def test_do_nothing_action_form_request():
|
||||
],
|
||||
) # flake8: noqa
|
||||
@pytest.mark.parametrize(
|
||||
"node_name, service_name, expect_to_do_nothing",
|
||||
"node_name, service_name, expect_failure",
|
||||
[
|
||||
("pc_1", "chrome", False),
|
||||
(None, "chrome", True),
|
||||
@@ -50,42 +46,15 @@ def test_do_nothing_action_form_request():
|
||||
(None, None, True),
|
||||
],
|
||||
) # flake8: noqa
|
||||
def test_service_action_form_request(node_name, service_name, expect_to_do_nothing, action_class, action_verb):
|
||||
def test_service_action_form_request(node_name, service_name, expect_failure, action_class, action_verb):
|
||||
"""Test that the ServiceScanAction can form a request and that it is correct."""
|
||||
manager: ActionManager = Mock()
|
||||
manager.get_node_name_by_idx.return_value = node_name
|
||||
manager.get_service_name_by_idx.return_value = service_name
|
||||
|
||||
action = action_class(manager=manager, num_nodes=1, num_services=1)
|
||||
|
||||
request = action.form_request(node_id=0, service_id=0)
|
||||
|
||||
if expect_to_do_nothing:
|
||||
assert request == ["do_nothing"]
|
||||
if expect_failure:
|
||||
with pytest.raises(ValidationError):
|
||||
request = action_class.form_request(
|
||||
config=action_class.ConfigSchema(node_name=node_name, service_name=service_name)
|
||||
)
|
||||
else:
|
||||
request = action_class.form_request(
|
||||
config=action_class.ConfigSchema(node_name=node_name, service_name=service_name)
|
||||
)
|
||||
assert request == ["network", "node", node_name, "service", service_name, action_verb]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"node_name, service_name, expect_to_do_nothing",
|
||||
[
|
||||
("pc_1", "chrome", False),
|
||||
(None, "chrome", True),
|
||||
("pc_1", None, True),
|
||||
(None, None, True),
|
||||
],
|
||||
) # flake8: noqa
|
||||
def test_service_scan_form_request(node_name, service_name, expect_to_do_nothing):
|
||||
"""Test that the ServiceScanAction can form a request and that it is correct."""
|
||||
manager: ActionManager = Mock()
|
||||
manager.get_node_name_by_idx.return_value = node_name
|
||||
manager.get_service_name_by_idx.return_value = service_name
|
||||
|
||||
action = NodeServiceScanAction(manager=manager, num_nodes=1, num_services=1)
|
||||
|
||||
request = action.form_request(node_id=0, service_id=0)
|
||||
|
||||
if expect_to_do_nothing:
|
||||
assert request == ["do_nothing"]
|
||||
else:
|
||||
assert request == ["network", "node", node_name, "service", service_name, "scan"]
|
||||
|
||||
52
tests/unit_tests/_primaite/_game/_agent/test_agent.py
Normal file
52
tests/unit_tests/_primaite/_game/_agent/test_agent.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation
|
||||
from primaite.game.agent.observations.observation_manager import NullObservation
|
||||
from primaite.game.agent.scripted_agents.random_agent import RandomAgent
|
||||
|
||||
|
||||
def test_creating_empty_agent():
|
||||
agent = RandomAgent()
|
||||
assert len(agent.action_manager.action_map) == 0
|
||||
assert isinstance(agent.observation_manager.obs, NullObservation)
|
||||
assert len(agent.reward_function.reward_components) == 0
|
||||
|
||||
|
||||
def test_creating_agent_from_dict():
|
||||
action_config = {
|
||||
"action_map": {
|
||||
0: {"action": "do_nothing", "options": {}},
|
||||
1: {
|
||||
"action": "node_application_execute",
|
||||
"options": {"node_name": "client", "application_name": "database"},
|
||||
},
|
||||
}
|
||||
}
|
||||
observation_config = {
|
||||
"type": "FILE",
|
||||
"options": {
|
||||
"file_name": "dog.pdf",
|
||||
"include_num_access": False,
|
||||
"file_system_requires_scan": False,
|
||||
},
|
||||
}
|
||||
reward_config = {
|
||||
"reward_components": [
|
||||
{
|
||||
"type": "DATABASE_FILE_INTEGRITY",
|
||||
"weight": 0.3,
|
||||
"options": {"node_hostname": "server", "folder_name": "database", "file_name": "database.db"},
|
||||
}
|
||||
]
|
||||
}
|
||||
agent = RandomAgent(
|
||||
config={
|
||||
"ref": "random_agent",
|
||||
"team": "BLUE",
|
||||
"action_space": action_config,
|
||||
"observation_space": observation_config,
|
||||
"reward_function": reward_config,
|
||||
}
|
||||
)
|
||||
|
||||
assert len(agent.action_manager.action_map) == 2
|
||||
assert isinstance(agent.observation_manager.obs, FileObservation)
|
||||
assert len(agent.reward_function.reward_components) == 1
|
||||
@@ -69,8 +69,8 @@ class TestFileSystemRequiresScan:
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
@@ -98,7 +98,7 @@ class TestFileSystemRequiresScan:
|
||||
"""
|
||||
|
||||
cfg = yaml.safe_load(obs_cfg_yaml)
|
||||
manager = ObservationManager.from_config(cfg)
|
||||
manager = ObservationManager(config=cfg)
|
||||
|
||||
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
|
||||
for i, host in enumerate(hosts):
|
||||
@@ -119,14 +119,20 @@ class TestFileSystemRequiresScan:
|
||||
assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3
|
||||
|
||||
def test_folder_require_scan(self):
|
||||
folder_state = {"health_status": 3, "visible_status": 1}
|
||||
folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": False}
|
||||
|
||||
obs_requiring_scan = FolderObservation(
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
|
||||
)
|
||||
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
|
||||
assert obs_requiring_scan.observe(folder_state)["health_status"] == 0
|
||||
|
||||
obs_not_requiring_scan = FolderObservation(
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False
|
||||
)
|
||||
assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3
|
||||
|
||||
folder_state = {"health_status": 3, "visible_status": 1, "scanned_this_step": True}
|
||||
obs_requiring_scan = FolderObservation(
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
|
||||
)
|
||||
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
|
||||
|
||||
@@ -3,6 +3,7 @@ from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import PrimaiteGame, PrimaiteGameOptions
|
||||
|
||||
|
||||
def test_probabilistic_agent():
|
||||
@@ -16,69 +17,58 @@ def test_probabilistic_agent():
|
||||
"""
|
||||
N_TRIALS = 10_000
|
||||
P_DO_NOTHING = 0.1
|
||||
P_NODE_APPLICATION_EXECUTE = 0.3
|
||||
P_NODE_FILE_DELETE = 0.6
|
||||
P_node_application_execute = 0.3
|
||||
P_node_file_delete = 0.6
|
||||
MIN_DO_NOTHING = 850
|
||||
MAX_DO_NOTHING = 1150
|
||||
MIN_NODE_APPLICATION_EXECUTE = 2800
|
||||
MAX_NODE_APPLICATION_EXECUTE = 3200
|
||||
MIN_NODE_FILE_DELETE = 5750
|
||||
MAX_NODE_FILE_DELETE = 6250
|
||||
MIN_node_application_execute = 2800
|
||||
MAX_node_application_execute = 3200
|
||||
MIN_node_file_delete = 5750
|
||||
MAX_node_file_delete = 6250
|
||||
|
||||
action_space = ActionManager(
|
||||
actions=[
|
||||
{"type": "DONOTHING"},
|
||||
{"type": "NODE_APPLICATION_EXECUTE"},
|
||||
{"type": "NODE_FILE_DELETE"},
|
||||
],
|
||||
nodes=[
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"applications": [{"application_name": "WebBrowser"}],
|
||||
"folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}],
|
||||
action_space_cfg = {
|
||||
"action_map": {
|
||||
0: {"action": "do_nothing", "options": {}},
|
||||
1: {
|
||||
"action": "node_application_execute",
|
||||
"options": {"node_name": "client_1", "application_name": "WebBrowser"},
|
||||
},
|
||||
2: {
|
||||
"action": "node_file_delete",
|
||||
"options": {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"},
|
||||
},
|
||||
],
|
||||
max_folders_per_node=2,
|
||||
max_files_per_folder=2,
|
||||
max_services_per_node=2,
|
||||
max_applications_per_node=2,
|
||||
max_nics_per_node=2,
|
||||
max_acl_rules=10,
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
ports=["HTTP", "DNS", "ARP"],
|
||||
act_map={
|
||||
0: {"action": "DONOTHING", "options": {}},
|
||||
1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}},
|
||||
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
},
|
||||
)
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
}
|
||||
|
||||
pa = ProbabilisticAgent(
|
||||
agent_name="test_agent",
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
settings={
|
||||
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
|
||||
game = PrimaiteGame()
|
||||
game.options = PrimaiteGameOptions(ports=[], protocols=[])
|
||||
|
||||
pa_config = {
|
||||
"type": "ProbabilisticAgent",
|
||||
"ref": "ProbabilisticAgent",
|
||||
"team": "BLUE",
|
||||
"action_space": action_space_cfg,
|
||||
"agent_settings": {
|
||||
"action_probabilities": {0: P_DO_NOTHING, 1: P_node_application_execute, 2: P_node_file_delete},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pa = ProbabilisticAgent.from_config(config=pa_config)
|
||||
|
||||
do_nothing_count = 0
|
||||
node_application_execute_count = 0
|
||||
node_file_delete_count = 0
|
||||
for _ in range(N_TRIALS):
|
||||
a = pa.get_action(0)
|
||||
if a == ("DONOTHING", {}):
|
||||
if a == ("do_nothing", {}):
|
||||
do_nothing_count += 1
|
||||
elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}):
|
||||
elif a == ("node_application_execute", {"node_name": "client_1", "application_name": "WebBrowser"}):
|
||||
node_application_execute_count += 1
|
||||
elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}):
|
||||
elif a == ("node_file_delete", {"node_name": "client_1", "folder_name": "downloads", "file_name": "cat.png"}):
|
||||
node_file_delete_count += 1
|
||||
else:
|
||||
raise AssertionError("Probabilistic agent produced an unexpected action.")
|
||||
|
||||
assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING
|
||||
assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE
|
||||
assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE
|
||||
assert MIN_node_application_execute < node_application_execute_count < MAX_node_application_execute
|
||||
assert MIN_node_file_delete < node_file_delete_count < MAX_node_file_delete
|
||||
|
||||
@@ -81,7 +81,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
reward = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -91,8 +91,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# agent did a successful fetch
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history.append({"outcome": 200})
|
||||
@@ -104,7 +104,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -114,8 +114,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == 0.0
|
||||
|
||||
# agent fails to fetch, get a -1.0 reward
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
@@ -126,8 +126,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# agent fails again to fetch, get a -1.0 reward again
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
@@ -142,7 +142,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
reward = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history = []
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
@@ -152,8 +152,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# agent did a successful fetch
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
browser_history.append({"outcome": 200})
|
||||
@@ -165,7 +165,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is sticky, it stays at 1.0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"WebBrowser": {"history": browser_history}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -174,8 +174,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# agent fails to fetch, get a -1.0 reward
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
@@ -186,8 +186,8 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# agent fails again to fetch, get a -1.0 reward again
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "WebBrowser"}
|
||||
request = ["network", "node", "computer", "application", "WebBrowser", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
browser_history.append({"outcome": 404})
|
||||
@@ -207,7 +207,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -216,8 +216,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# agent did a successful fetch
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
@@ -228,7 +228,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -237,8 +237,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == 0.0
|
||||
|
||||
# agent fails to fetch, get a -1.0 reward
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
@@ -248,8 +248,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# agent fails again to fetch, get a -1.0 reward again
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
@@ -266,7 +266,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -275,8 +275,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == 0
|
||||
|
||||
# agent did a successful fetch
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
@@ -287,7 +287,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
|
||||
# THE IMPORTANT BIT
|
||||
# agent did nothing, because reward is not sticky, it goes back to 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
action, params, request = "do_nothing", {}, ["do_nothing"]
|
||||
response = RequestResponse(status="success", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
last_action_response = AgentHistoryItem(
|
||||
@@ -296,8 +296,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == 1.0
|
||||
|
||||
# agent fails to fetch, get a -1.0 reward
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
@@ -307,8 +307,8 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
# agent fails again to fetch, get a -1.0 reward again
|
||||
action = "NODE_APPLICATION_EXECUTE"
|
||||
params = {"node_id": 0, "application_id": 0}
|
||||
action = "node_application_execute"
|
||||
params = {"node_name": "computer", "application_name": "DatabaseClient"}
|
||||
request = ["network", "node", "computer", "application", "DatabaseClient", "execute"]
|
||||
response = RequestResponse(status="failure", data={})
|
||||
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
|
||||
|
||||
@@ -22,12 +22,12 @@ def test_file_scan(file_system):
|
||||
file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
|
||||
assert file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
file.corrupt()
|
||||
|
||||
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
file.scan()
|
||||
|
||||
@@ -46,7 +46,7 @@ def test_file_reveal_to_red_scan(file_system):
|
||||
assert file.revealed_to_red is True
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
|
||||
@pytest.mark.skip(reason="node_file_checkhash not implemented")
|
||||
def test_simulated_file_check_hash(file_system):
|
||||
file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def test_file_scan_request(populated_file_system):
|
||||
|
||||
file.corrupt()
|
||||
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
fs.apply_request(request=["folder", folder.name, "file", file.name, "scan"])
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_file_scan_request(populated_file_system):
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
|
||||
@pytest.mark.skip(reason="node_file_checkhash not implemented")
|
||||
def test_file_checkhash_request(populated_file_system):
|
||||
"""Test that an agent can request a file hash check."""
|
||||
fs, folder, file = populated_file_system
|
||||
@@ -94,7 +94,7 @@ def test_deleted_file_cannot_be_interacted_with(populated_file_system):
|
||||
assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert (
|
||||
fs.get_file(folder_name=folder.name, file_name=file.name).visible_health_status
|
||||
== FileSystemItemHealthStatus.GOOD
|
||||
== FileSystemItemHealthStatus.NONE
|
||||
)
|
||||
|
||||
fs.apply_request(request=["delete", "file", folder.name, file.name])
|
||||
|
||||
@@ -44,25 +44,25 @@ def test_folder_scan(file_system):
|
||||
file2: File = folder.get_file_by_id(file_uuid=list(folder.files)[0])
|
||||
|
||||
assert folder.health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
folder.corrupt()
|
||||
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
folder.scan()
|
||||
|
||||
folder.apply_timestep(timestep=0)
|
||||
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
folder.apply_timestep(timestep=1)
|
||||
folder.apply_timestep(timestep=2)
|
||||
@@ -120,7 +120,7 @@ def test_folder_corrupt_repair(file_system):
|
||||
assert file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
|
||||
@pytest.mark.skip(reason="node_file_checkhash not implemented")
|
||||
def test_simulated_folder_check_hash(file_system):
|
||||
folder: Folder = file_system.create_folder(folder_name="test_folder")
|
||||
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
|
||||
@@ -29,18 +29,18 @@ def test_folder_scan_request(populated_file_system):
|
||||
|
||||
folder.corrupt()
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
fs.apply_request(request=["folder", folder.name, "scan"])
|
||||
|
||||
folder.apply_timestep(timestep=0)
|
||||
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file1.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
folder.apply_timestep(timestep=1)
|
||||
folder.apply_timestep(timestep=2)
|
||||
@@ -51,7 +51,7 @@ def test_folder_scan_request(populated_file_system):
|
||||
assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FOLDER_CHECKHASH not implemented")
|
||||
@pytest.mark.skip(reason="node_folder_checkhash not implemented")
|
||||
def test_folder_checkhash_request(populated_file_system):
|
||||
"""Test that an agent can request a folder hash check."""
|
||||
fs, folder, file = populated_file_system
|
||||
|
||||
@@ -70,13 +70,13 @@ def test_node_os_scan(node):
|
||||
# add folder and file to node
|
||||
folder: Folder = node.file_system.create_folder(folder_name="test_folder")
|
||||
folder.corrupt()
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert folder.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
file: File = node.file_system.create_file(folder_name="test_folder", file_name="file.txt")
|
||||
file2: File = node.file_system.create_file(folder_name="test_folder", file_name="file2.txt")
|
||||
file.corrupt()
|
||||
file2.corrupt()
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
|
||||
assert file.visible_health_status == FileSystemItemHealthStatus.NONE
|
||||
|
||||
# run os scan
|
||||
node.apply_request(["os", "scan"])
|
||||
|
||||
@@ -128,13 +128,13 @@ def test_c2_handle_switching_port(basic_c2_network):
|
||||
assert c2_server.c2_connection_active is True
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 2
|
||||
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_beacon.config.keep_alive_frequency is 2
|
||||
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
assert c2_server.c2_config.keep_alive_frequency is 2
|
||||
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_server.config.keep_alive_frequency is 2
|
||||
assert c2_server.config.masquerade_port is PORT_LOOKUP["HTTP"]
|
||||
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
# Configuring the C2 Beacon.
|
||||
c2_beacon.configure(
|
||||
@@ -150,11 +150,11 @@ def test_c2_handle_switching_port(basic_c2_network):
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon
|
||||
# Have reconfigured their C2 settings.
|
||||
assert c2_beacon.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_beacon.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_beacon.config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_beacon.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
assert c2_server.c2_config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_server.c2_config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
assert c2_server.config.masquerade_port is PORT_LOOKUP["FTP"]
|
||||
assert c2_server.config.masquerade_protocol is PROTOCOL_LOOKUP["TCP"]
|
||||
|
||||
|
||||
def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
@@ -174,8 +174,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
assert c2_server.c2_connection_active is True
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon are configured correctly.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 2
|
||||
assert c2_server.c2_config.keep_alive_frequency is 2
|
||||
assert c2_beacon.config.keep_alive_frequency is 2
|
||||
assert c2_server.config.keep_alive_frequency is 2
|
||||
|
||||
# Configuring the C2 Beacon.
|
||||
c2_beacon.configure(c2_server_ip_address="192.168.0.1", keep_alive_frequency=10)
|
||||
@@ -186,8 +186,8 @@ def test_c2_handle_switching_frequency(basic_c2_network):
|
||||
|
||||
# Assert to confirm that both the C2 server and the C2 beacon
|
||||
# Have reconfigured their C2 settings.
|
||||
assert c2_beacon.c2_config.keep_alive_frequency is 10
|
||||
assert c2_server.c2_config.keep_alive_frequency is 10
|
||||
assert c2_beacon.config.keep_alive_frequency is 10
|
||||
assert c2_server.config.keep_alive_frequency is 10
|
||||
|
||||
# Now skipping 9 time steps to confirm keep alive inactivity
|
||||
for i in range(9):
|
||||
|
||||
@@ -148,7 +148,7 @@ def test_service_fixing(service):
|
||||
service.fix()
|
||||
assert service.health_state_actual == SoftwareHealthState.FIXING
|
||||
|
||||
for i in range(service.fixing_duration + 1):
|
||||
for i in range(service.config.fixing_duration + 1):
|
||||
service.apply_timestep(i)
|
||||
|
||||
assert service.health_state_actual == SoftwareHealthState.GOOD
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
@@ -10,7 +11,14 @@ from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
|
||||
|
||||
class TestSoftware(Service):
|
||||
class TestSoftware(Service, identifier="TestSoftware"):
|
||||
class ConfigSchema(Service.ConfigSchema):
|
||||
"""ConfigSChema for TestSoftware."""
|
||||
|
||||
type: str = "TestSoftware"
|
||||
|
||||
config: "TestSoftware.ConfigSchema" = Field(default_factory=lambda: TestSoftware.ConfigSchema())
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user