diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 7533f6f3..a5738d76 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -72,6 +72,7 @@ class WebBrowser(Application): """ state = super().describe_state() state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None + return state def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 37a680c8..85660796 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -10,11 +10,13 @@ # 4. Check that the simulation has changed in the way that I expect. # 5. Repeat for all actions. +from typing import Dict, Tuple + import pytest from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import ProxyAgent -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.interface import AbstractAgent, ProxyAgent +from primaite.game.agent.observations import ICSObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -29,6 +31,34 @@ from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.web_server.web_server import WebServer +from primaite.simulator.system.software import SoftwareHealthState + + +class ControlledAgent(AbstractAgent): + """Agent that can be controlled by the tests.""" + + def __init__( + self, + agent_name: str, + action_space: ActionManager, + observation_space: ObservationManager, + reward_function: RewardFunction, + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: Tuple[str, Dict] + + def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + """Return the agent's most recent action, formatted in CAOS format.""" + return self.most_recent_action + + def store_action(self, action: Tuple[str, Dict]): + """Store the most recent action.""" + self.most_recent_action = action def install_stuff_to_sim(sim: Simulation): @@ -105,12 +135,47 @@ def install_stuff_to_sim(sim: Simulation): assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser) assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient) - # 5: Return the simulation + # 5: Assert that the simulation starts off in the state that we expect + assert len(sim.network.nodes) == 6 + assert len(sim.network.links) == 5 + # 5.1: Assert the router is correctly configured + r = sim.network.routers[0] + for i, acl_rule in enumerate(r.acl.acl): + if i == 1: + assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + elif i == 3: + assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + elif i == 22: + assert acl_rule.src_port == acl_rule.dst_port == Port.ARP + elif i == 23: + assert acl_rule.protocol == IPProtocol.ICMP + elif i == 24: + ... + else: + assert acl_rule is None + + # 5.2: Assert the client is correctly configured + c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] + assert c.software_manager.software.get("WebBrowser") is not None + assert c.software_manager.software.get("DNSClient") is not None + assert str(c.ethernet_port[1].ip_address) == "10.0.1.2" + + # 5.3: Assert that server_1 is correctly configured + s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + assert str(s1.ethernet_port[1].ip_address) == "10.0.2.2" + assert s1.software_manager.software.get("DNSServer") is not None + + # 5.4: Assert that server_2 is correctly configured + s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + assert str(s2.ethernet_port[1].ip_address) == "10.0.2.3" + assert s2.software_manager.software.get("WebServer") is not None + + # 6: Return the simulation return sim @pytest.fixture -def game(): +def game_and_agent(): """Create a game with a simple agent that can be controlled by the tests.""" game = PrimaiteGame() sim = game.simulation @@ -166,10 +231,10 @@ def game(): ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], act_map={}, ) - observation_space = None - reward_function = None + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() - test_agent = ProxyAgent( + test_agent = ControlledAgent( agent_name="test_agent", action_space=action_space, observation_space=observation_space, @@ -178,8 +243,47 @@ def game(): game.agents.append(test_agent) - return game, test_agent + return (game, test_agent) -def test_test(game): - assert True +# def test_test(game_and_agent:Tuple[PrimaiteGame, ProxyAgent]): +# game, agent = game_and_agent + + +def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the DoNothingAction can form a request and that it is accepted by the simulation.""" + game, agent = game_and_agent + + action = ("DONOTHING", {}) + agent.store_action(action) + game.step() + + +@pytest.mark.skip(reason="Waiting to merge ticket 2160") +def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): + """ + Test that the NodeServiceScanAction can form a request and that it is accepted by the simulation. + + The health status of applications is not always updated in the state dict, rather the agent needs to perform a scan. + Therefore, we set the web browser to be corrupted, check the state is still good, then perform a scan, and check + that the state changes to the true value. + """ + game, agent = game_and_agent + + browser = game.simulation.network.get_node_by_hostname("client_1").software_manager.software.get("WebBrowser") + browser.health_state_actual = SoftwareHealthState.COMPROMISED + + state_before = game.get_sim_state() + assert ( + game.get_sim_state()["network"]["nodes"]["client_1"]["applications"]["WebBrowser"]["health_state"] + == SoftwareHealthState.GOOD + ) + action = ("NODE_SERVICE_SCAN", {"node_id": 0, "service_id": 0}) + agent.store_action(action) + game.step() + state_after = game.get_sim_state() + pass + assert ( + game.get_sim_state()["network"]["nodes"]["client_1"]["services"]["WebBrowser"]["health_state"] + == SoftwareHealthState.COMPROMISED + )