From 41bc932f524a0e14456f2112bd7013d94e3c5330 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 6 Feb 2024 15:05:44 +0000 Subject: [PATCH] Add reward test. --- tests/conftest.py | 233 ++++++++++++++++++ .../game_layer/test_actions.py | 228 ----------------- .../game_layer/test_rewards.py | 37 +++ 3 files changed, 270 insertions(+), 228 deletions(-) create mode 100644 tests/integration_tests/game_layer/test_rewards.py diff --git a/tests/conftest.py b/tests/conftest.py index c37226a5..510a9df0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,10 @@ import pytest import yaml from primaite import getLogger +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.interface import AbstractAgent +from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.session.session import PrimaiteSession @@ -20,9 +24,14 @@ from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.application import Application +from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.core.sys_log import SysLog +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service +from primaite.simulator.system.services.web_server.web_server import WebServer from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path ACTION_SPACE_NODE_VALUES = 1 @@ -237,3 +246,227 @@ def example_network() -> Network: router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) return network + + +class ControlledAgent(AbstractAgent): + """Agent that can be controlled by the tests.""" + + def __init__( + self, + agent_name: str, + action_space: ActionManager, + observation_space: ObservationManager, + reward_function: RewardFunction, + ) -> None: + super().__init__( + agent_name=agent_name, + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + self.most_recent_action: Tuple[str, Dict] + + def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + """Return the agent's most recent action, formatted in CAOS format.""" + return self.most_recent_action + + def store_action(self, action: Tuple[str, Dict]): + """Store the most recent action.""" + self.most_recent_action = action + + +def install_stuff_to_sim(sim: Simulation): + """Create a simulation with a computer, two servers, two switches, and a router.""" + + # 0: Pull out the network + network = sim.network + + # 1: Set up network hardware + # 1.1: Configure the router + router = Router(hostname="router", num_ports=3, operating_state=NodeOperatingState.ON) + router.power_on() + router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") + router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") + + # 1.2: Create and connect switches + switch_1 = Switch(hostname="switch_1", num_ports=6, operating_state=NodeOperatingState.ON) + switch_1.power_on() + network.connect(endpoint_a=router.ethernet_ports[1], endpoint_b=switch_1.switch_ports[6]) + router.enable_port(1) + switch_2 = Switch(hostname="switch_2", num_ports=6, operating_state=NodeOperatingState.ON) + switch_2.power_on() + network.connect(endpoint_a=router.ethernet_ports[2], endpoint_b=switch_2.switch_ports[6]) + router.enable_port(2) + + # 1.3: Create and connect computer + client_1 = Computer( + hostname="client_1", + ip_address="10.0.1.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.1.1", + operating_state=NodeOperatingState.ON, + ) + client_1.power_on() + network.connect( + endpoint_a=client_1.ethernet_port[1], + endpoint_b=switch_1.switch_ports[1], + ) + + # 1.4: Create and connect servers + server_1 = Server( + hostname="server_1", + ip_address="10.0.2.2", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + operating_state=NodeOperatingState.ON, + ) + server_1.power_on() + network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_2.switch_ports[1]) + + server_2 = Server( + hostname="server_2", + ip_address="10.0.2.3", + subnet_mask="255.255.255.0", + default_gateway="10.0.2.1", + operating_state=NodeOperatingState.ON, + ) + server_2.power_on() + network.connect(endpoint_a=server_2.ethernet_port[1], endpoint_b=switch_2.switch_ports[2]) + + # 2: Configure base ACL + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + + # 3: Install server software + server_1.software_manager.install(DNSServer) + dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa + dns_service.dns_register("www.example.com", server_2.ip_address) + server_2.software_manager.install(WebServer) + + # 3.1: Ensure that the dns clients are configured correctly + client_1.software_manager.software.get("DNSClient").dns_server = server_1.ethernet_port[1].ip_address + server_2.software_manager.software.get("DNSClient").dns_server = server_1.ethernet_port[1].ip_address + + # 4: Check that client came pre-installed with web browser and dns client + assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser) + assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient) + + # 4.1: Create a file on the computer + client_1.file_system.create_file("cat.png", 300, folder_name="downloads") + + # 5: Assert that the simulation starts off in the state that we expect + assert len(sim.network.nodes) == 6 + assert len(sim.network.links) == 5 + # 5.1: Assert the router is correctly configured + r = sim.network.routers[0] + for i, acl_rule in enumerate(r.acl.acl): + if i == 1: + assert acl_rule.src_port == acl_rule.dst_port == Port.DNS + elif i == 3: + assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP + elif i == 22: + assert acl_rule.src_port == acl_rule.dst_port == Port.ARP + elif i == 23: + assert acl_rule.protocol == IPProtocol.ICMP + elif i == 24: + ... + else: + assert acl_rule is None + + # 5.2: Assert the client is correctly configured + c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] + assert c.software_manager.software.get("WebBrowser") is not None + assert c.software_manager.software.get("DNSClient") is not None + assert str(c.ethernet_port[1].ip_address) == "10.0.1.2" + + # 5.3: Assert that server_1 is correctly configured + s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + assert str(s1.ethernet_port[1].ip_address) == "10.0.2.2" + assert s1.software_manager.software.get("DNSServer") is not None + + # 5.4: Assert that server_2 is correctly configured + s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + assert str(s2.ethernet_port[1].ip_address) == "10.0.2.3" + assert s2.software_manager.software.get("WebServer") is not None + + # 6: Return the simulation + return sim + + +@pytest.fixture +def game_and_agent(): + """Create a game with a simple agent that can be controlled by the tests.""" + game = PrimaiteGame() + sim = game.simulation + install_stuff_to_sim(sim) + + actions = [ + {"type": "DONOTHING"}, + {"type": "NODE_SERVICE_SCAN"}, + {"type": "NODE_SERVICE_STOP"}, + {"type": "NODE_SERVICE_START"}, + {"type": "NODE_SERVICE_PAUSE"}, + {"type": "NODE_SERVICE_RESUME"}, + {"type": "NODE_SERVICE_RESTART"}, + {"type": "NODE_SERVICE_DISABLE"}, + {"type": "NODE_SERVICE_ENABLE"}, + {"type": "NODE_SERVICE_PATCH"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_SCAN"}, + {"type": "NODE_FILE_CHECKHASH"}, + {"type": "NODE_FILE_DELETE"}, + {"type": "NODE_FILE_REPAIR"}, + {"type": "NODE_FILE_RESTORE"}, + {"type": "NODE_FILE_CORRUPT"}, + {"type": "NODE_FOLDER_SCAN"}, + {"type": "NODE_FOLDER_CHECKHASH"}, + {"type": "NODE_FOLDER_REPAIR"}, + {"type": "NODE_FOLDER_RESTORE"}, + {"type": "NODE_OS_SCAN"}, + {"type": "NODE_SHUTDOWN"}, + {"type": "NODE_STARTUP"}, + {"type": "NODE_RESET"}, + {"type": "NETWORK_ACL_ADDRULE", "options": {"target_router_hostname": "router"}}, + {"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}}, + {"type": "NETWORK_NIC_ENABLE"}, + {"type": "NETWORK_NIC_DISABLE"}, + ] + + action_space = ActionManager( + game=game, + actions=actions, # ALL POSSIBLE ACTIONS + nodes=[ + { + "node_name": "client_1", + "applications": [{"application_name": "WebBrowser"}], + "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + }, + {"node_name": "server_1", "services": [{"service_name": "DNSServer"}]}, + {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], + act_map={}, + ) + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() + + test_agent = ControlledAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + ) + + game.agents.append(test_agent) + + return (game, test_agent) diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index e771dbd2..c5e09195 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -35,234 +35,6 @@ from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState -class ControlledAgent(AbstractAgent): - """Agent that can be controlled by the tests.""" - - def __init__( - self, - agent_name: str, - action_space: ActionManager, - observation_space: ObservationManager, - reward_function: RewardFunction, - ) -> None: - super().__init__( - agent_name=agent_name, - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) - self.most_recent_action: Tuple[str, Dict] - - def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: - """Return the agent's most recent action, formatted in CAOS format.""" - return self.most_recent_action - - def store_action(self, action: Tuple[str, Dict]): - """Store the most recent action.""" - self.most_recent_action = action - - -def install_stuff_to_sim(sim: Simulation): - """Create a simulation with a computer, two servers, two switches, and a router.""" - - # 0: Pull out the network - network = sim.network - - # 1: Set up network hardware - # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, operating_state=NodeOperatingState.ON) - router.power_on() - router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") - router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") - - # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, operating_state=NodeOperatingState.ON) - switch_1.power_on() - network.connect(endpoint_a=router.ethernet_ports[1], endpoint_b=switch_1.switch_ports[6]) - router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, operating_state=NodeOperatingState.ON) - switch_2.power_on() - network.connect(endpoint_a=router.ethernet_ports[2], endpoint_b=switch_2.switch_ports[6]) - router.enable_port(2) - - # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - operating_state=NodeOperatingState.ON, - ) - client_1.power_on() - network.connect( - endpoint_a=client_1.ethernet_port[1], - endpoint_b=switch_1.switch_ports[1], - ) - - # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - operating_state=NodeOperatingState.ON, - ) - server_1.power_on() - network.connect(endpoint_a=server_1.ethernet_port[1], endpoint_b=switch_2.switch_ports[1]) - - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - operating_state=NodeOperatingState.ON, - ) - server_2.power_on() - network.connect(endpoint_a=server_2.ethernet_port[1], endpoint_b=switch_2.switch_ports[2]) - - # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) - - # 3: Install server software - server_1.software_manager.install(DNSServer) - dns_service: DNSServer = server_1.software_manager.software.get("DNSServer") # noqa - dns_service.dns_register("www.example.com", server_2.ip_address) - server_2.software_manager.install(WebServer) - - # 3.1: Ensure that the dns clients are configured correctly - client_1.software_manager.software.get("DNSClient").dns_server = server_1.ethernet_port[1].ip_address - server_2.software_manager.software.get("DNSClient").dns_server = server_1.ethernet_port[1].ip_address - - # 4: Check that client came pre-installed with web browser and dns client - assert isinstance(client_1.software_manager.software.get("WebBrowser"), WebBrowser) - assert isinstance(client_1.software_manager.software.get("DNSClient"), DNSClient) - - # 4.1: Create a file on the computer - client_1.file_system.create_file("cat.png", 300, folder_name="downloads") - - # 5: Assert that the simulation starts off in the state that we expect - assert len(sim.network.nodes) == 6 - assert len(sim.network.links) == 5 - # 5.1: Assert the router is correctly configured - r = sim.network.routers[0] - for i, acl_rule in enumerate(r.acl.acl): - if i == 1: - assert acl_rule.src_port == acl_rule.dst_port == Port.DNS - elif i == 3: - assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP - elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port.ARP - elif i == 23: - assert acl_rule.protocol == IPProtocol.ICMP - elif i == 24: - ... - else: - assert acl_rule is None - - # 5.2: Assert the client is correctly configured - c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] - assert c.software_manager.software.get("WebBrowser") is not None - assert c.software_manager.software.get("DNSClient") is not None - assert str(c.ethernet_port[1].ip_address) == "10.0.1.2" - - # 5.3: Assert that server_1 is correctly configured - s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] - assert str(s1.ethernet_port[1].ip_address) == "10.0.2.2" - assert s1.software_manager.software.get("DNSServer") is not None - - # 5.4: Assert that server_2 is correctly configured - s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] - assert str(s2.ethernet_port[1].ip_address) == "10.0.2.3" - assert s2.software_manager.software.get("WebServer") is not None - - # 6: Return the simulation - return sim - - -@pytest.fixture -def game_and_agent(): - """Create a game with a simple agent that can be controlled by the tests.""" - game = PrimaiteGame() - sim = game.simulation - install_stuff_to_sim(sim) - - actions = [ - {"type": "DONOTHING"}, - {"type": "NODE_SERVICE_SCAN"}, - {"type": "NODE_SERVICE_STOP"}, - {"type": "NODE_SERVICE_START"}, - {"type": "NODE_SERVICE_PAUSE"}, - {"type": "NODE_SERVICE_RESUME"}, - {"type": "NODE_SERVICE_RESTART"}, - {"type": "NODE_SERVICE_DISABLE"}, - {"type": "NODE_SERVICE_ENABLE"}, - {"type": "NODE_SERVICE_PATCH"}, - {"type": "NODE_APPLICATION_EXECUTE"}, - {"type": "NODE_FILE_SCAN"}, - {"type": "NODE_FILE_CHECKHASH"}, - {"type": "NODE_FILE_DELETE"}, - {"type": "NODE_FILE_REPAIR"}, - {"type": "NODE_FILE_RESTORE"}, - {"type": "NODE_FILE_CORRUPT"}, - {"type": "NODE_FOLDER_SCAN"}, - {"type": "NODE_FOLDER_CHECKHASH"}, - {"type": "NODE_FOLDER_REPAIR"}, - {"type": "NODE_FOLDER_RESTORE"}, - {"type": "NODE_OS_SCAN"}, - {"type": "NODE_SHUTDOWN"}, - {"type": "NODE_STARTUP"}, - {"type": "NODE_RESET"}, - {"type": "NETWORK_ACL_ADDRULE", "options": {"target_router_hostname": "router"}}, - {"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}}, - {"type": "NETWORK_NIC_ENABLE"}, - {"type": "NETWORK_NIC_DISABLE"}, - ] - - action_space = ActionManager( - game=game, - actions=actions, # ALL POSSIBLE ACTIONS - nodes=[ - { - "node_name": "client_1", - "applications": [{"application_name": "WebBrowser"}], - "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], - }, - {"node_name": "server_1", "services": [{"service_name": "DNSServer"}]}, - {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, - ], - max_folders_per_node=2, - max_files_per_folder=2, - max_services_per_node=2, - max_applications_per_node=2, - max_nics_per_node=2, - max_acl_rules=10, - protocols=["TCP", "UDP", "ICMP"], - ports=["HTTP", "DNS", "ARP"], - ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], - act_map={}, - ) - observation_space = ObservationManager(ICSObservation()) - reward_function = RewardFunction() - - test_agent = ControlledAgent( - agent_name="test_agent", - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) - - game.agents.append(test_agent) - - return (game, test_agent) - - -# def test_test(game_and_agent:Tuple[PrimaiteGame, ProxyAgent]): -# game, agent = game_and_agent - - def test_do_nothing_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): """Test that the DoNothingAction can form a request and that it is accepted by the simulation.""" game, agent = game_and_agent diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py new file mode 100644 index 00000000..c084512f --- /dev/null +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -0,0 +1,37 @@ +from primaite.game.agent.rewards import RewardFunction, WebpageUnavailablePenalty +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from tests.conftest import ControlledAgent + + +def test_WebpageUnavailablePenalty(game_and_agent): + """Test that we get the right reward for failing to fetch a website.""" + game, agent = game_and_agent + agent: ControlledAgent + comp = WebpageUnavailablePenalty(node_hostname="client_1") + + agent.reward_function.register_component(comp, 0.7) + action = ("DONOTHING", {}) + agent.store_action(action) + game.step() + + # client 1 has not attempted to fetch webpage yet! + assert agent.reward_function.current_reward == 0.0 + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + browser = client_1.software_manager.software.get("WebBrowser") + browser.run() + browser.target_url = "http://www.example.com" + assert browser.get_webpage() + action = ("DONOTHING", {}) + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == 0.7 + + router: Router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.DENY, protocol=IPProtocol.TCP, src_port=Port.HTTP, dst_port=Port.HTTP) + assert not browser.get_webpage() + agent.store_action(action) + game.step() + assert agent.reward_function.current_reward == -0.7