Add integration tests

This commit is contained in:
Marek Wolan
2024-01-04 12:47:35 +00:00
parent 25c8ec2ec9
commit 528e3b22a9
2 changed files with 115 additions and 10 deletions

View File

@@ -72,6 +72,7 @@ class WebBrowser(Application):
"""
state = super().describe_state()
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
return state
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""

View File

@@ -10,11 +10,13 @@
# 4. Check that the simulation has changed in the way that I expect.
# 5. Repeat for all actions.
from typing import Dict, Tuple
import pytest
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
@@ -29,6 +31,34 @@ from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import SoftwareHealthState
class ControlledAgent(AbstractAgent):
"""Agent that can be controlled by the tests."""
def __init__(
self,
agent_name: str,
action_space: ActionManager,
observation_space: ObservationManager,
reward_function: RewardFunction,
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: Tuple[str, Dict]
def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]:
"""Return the agent's most recent action, formatted in CAOS format."""
return self.most_recent_action
def store_action(self, action: Tuple[str, Dict]):
"""Store the most recent action."""
self.most_recent_action = action
def install_stuff_to_sim(sim: Simulation):
@@ -105,12 +135,47 @@ def install_stuff_to_sim(sim: Simulation):
assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser)
assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient)
# 5: Return the simulation
# 5: Assert that the simulation starts off in the state that we expect
assert len(sim.network.nodes) == 6
assert len(sim.network.links) == 5
# 5.1: Assert the router is correctly configured
r = sim.network.routers[0]
for i, acl_rule in enumerate(r.acl.acl):
if i == 1:
assert acl_rule.src_port == acl_rule.dst_port == Port.DNS
elif i == 3:
assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP
elif i == 22:
assert acl_rule.src_port == acl_rule.dst_port == Port.ARP
elif i == 23:
assert acl_rule.protocol == IPProtocol.ICMP
elif i == 24:
...
else:
assert acl_rule is None
# 5.2: Assert the client is correctly configured
c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0]
assert c.software_manager.software.get("WebBrowser") is not None
assert c.software_manager.software.get("DNSClient") is not None
assert str(c.ethernet_port[1].ip_address) == "10.0.1.2"
# 5.3: Assert that server_1 is correctly configured
s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0]
assert str(s1.ethernet_port[1].ip_address) == "10.0.2.2"
assert s1.software_manager.software.get("DNSServer") is not None
# 5.4: Assert that server_2 is correctly configured
s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0]
assert str(s2.ethernet_port[1].ip_address) == "10.0.2.3"
assert s2.software_manager.software.get("WebServer") is not None
# 6: Return the simulation
return sim
@pytest.fixture
def game():
def game_and_agent():
"""Create a game with a simple agent that can be controlled by the tests."""
game = PrimaiteGame()
sim = game.simulation
@@ -166,10 +231,10 @@ def game():
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = None
reward_function = None
observation_space = ObservationManager(ICSObservation())
reward_function = RewardFunction()
test_agent = ProxyAgent(
test_agent = ControlledAgent(
agent_name="test_agent",
action_space=action_space,
observation_space=observation_space,
@@ -178,8 +243,47 @@ def game():
game.agents.append(test_agent)
return game, test_agent
return (game, test_agent)
def test_test(game):
assert True
# def test_test(game_and_agent:Tuple[PrimaiteGame, ProxyAgent]):
# game, agent = game_and_agent
def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the DoNothingAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
action = ("DONOTHING", {})
agent.store_action(action)
game.step()
@pytest.mark.skip(reason="Waiting to merge ticket 2160")
def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""
Test that the NodeServiceScanAction can form a request and that it is accepted by the simulation.
The health status of applications is not always updated in the state dict, rather the agent needs to perform a scan.
Therefore, we set the web browser to be corrupted, check the state is still good, then perform a scan, and check
that the state changes to the true value.
"""
game, agent = game_and_agent
browser = game.simulation.network.get_node_by_hostname("client_1").software_manager.software.get("WebBrowser")
browser.health_state_actual = SoftwareHealthState.COMPROMISED
state_before = game.get_sim_state()
assert (
game.get_sim_state()["network"]["nodes"]["client_1"]["applications"]["WebBrowser"]["health_state"]
== SoftwareHealthState.GOOD
)
action = ("NODE_SERVICE_SCAN", {"node_id": 0, "service_id": 0})
agent.store_action(action)
game.step()
state_after = game.get_sim_state()
pass
assert (
game.get_sim_state()["network"]["nodes"]["client_1"]["services"]["WebBrowser"]["health_state"]
== SoftwareHealthState.COMPROMISED
)