From 30c177c2722ea43a80dbf918d2839b883f57a63c Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 23 Jan 2025 17:07:15 +0000 Subject: [PATCH] #2887 - Additional test failure fixes --- .../simulator/network/hardware/base.py | 2 +- .../hardware/nodes/network/wireless_router.py | 8 +-- tests/conftest.py | 57 ++++++++++--------- .../observations/test_firewall_observation.py | 5 +- .../observations/test_link_observations.py | 10 ++-- .../observations/test_nic_observations.py | 6 +- .../observations/test_node_observations.py | 4 +- .../observations/test_router_observation.py | 4 +- .../test_software_observations.py | 4 +- .../game_layer/test_action_mask.py | 1 + .../game_layer/test_actions.py | 1 + .../network/test_airspace_config.py | 1 + .../network/test_broadcast.py | 47 ++++++++------- .../_system/_applications/test_web_browser.py | 22 +++---- 14 files changed, 88 insertions(+), 84 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index f68b627a..d462f75c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1760,7 +1760,7 @@ class Node(SimComponent, ABC): self.software_manager.install(application_class) application_instance = self.software_manager.software.get(application_name) self.applications[application_instance.uuid] = application_instance - _LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}") + _LOGGER.debug(f"Added application {application_instance.name} to node {self.config.hostname}") self._application_request_manager.add_request( application_name, RequestType(func=application_instance._request_manager) ) diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index 2c4b5976..75e4d5ea 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -126,16 +126,16 @@ class WirelessRouter(Router, identifier="wireless_router"): config: "WirelessRouter.ConfigSchema" = Field(default_factory=lambda: WirelessRouter.ConfigSchema()) - class ConfigSchema(Router.ConfigSChema): + class ConfigSchema(Router.ConfigSchema): """Configuration Schema for WirelessRouter nodes within PrimAITE.""" hostname: str = "WirelessRouter" - def __init__(self, hostname: str, airspace: AirSpace, **kwargs): - super().__init__(hostname=hostname, num_ports=0, airspace=airspace, **kwargs) + def __init__(self, **kwargs): + super().__init__(hostname=kwargs["config"].hostname, num_ports=0, airspace=kwargs["config"].airspace, **kwargs) self.connect_nic( - WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=airspace) + WirelessAccessPoint(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", airspace=kwargs["config"].airspace) ) self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) diff --git a/tests/conftest.py b/tests/conftest.py index fc86bb4d..1bdc217c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -348,29 +348,29 @@ def install_stuff_to_sim(sim: Simulation): # 1: Set up network hardware # 1.1: Configure the router - router = Router(hostname="router", num_ports=3, start_up_duration=0) + router = Router.from_config(config={"type":"router", "hostname":"router", "num_ports":3, "start_up_duration":0}) router.power_on() router.configure_port(port=1, ip_address="10.0.1.1", subnet_mask="255.255.255.0") router.configure_port(port=2, ip_address="10.0.2.1", subnet_mask="255.255.255.0") # 1.2: Create and connect switches - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1 = Switch.from_config(config={"type":"switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) switch_1.power_on() network.connect(endpoint_a=router.network_interface[1], endpoint_b=switch_1.network_interface[6]) router.enable_port(1) - switch_2 = Switch(hostname="switch_2", num_ports=6, start_up_duration=0) + switch_2 = Switch.from_config(config={"type":"switch", "hostname":"switch_2", "num_ports":6, "start_up_duration":0}) switch_2.power_on() network.connect(endpoint_a=router.network_interface[2], endpoint_b=switch_2.network_interface[6]) router.enable_port(2) # 1.3: Create and connect computer - client_1 = Computer( - hostname="client_1", - ip_address="10.0.1.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.1.1", - start_up_duration=0, - ) + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address":"10.0.1.2", + "subnet_mask":"255.255.255.0", + "default_gateway": "10.0.1.1", + "start_up_duration":0} + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() network.connect( endpoint_a=client_1.network_interface[1], @@ -378,23 +378,26 @@ def install_stuff_to_sim(sim: Simulation): ) # 1.4: Create and connect servers - server_1 = Server( - hostname="server_1", - ip_address="10.0.2.2", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + server_1_cfg = {"type": "server", + "hostname":"server_1", + "ip_address": "10.0.2.2", + "subnet_mask":"255.255.255.0", + "default_gateway":"10.0.2.1", + "start_up_duration": 0} + + + server_1: Server = Server.from_config(config=server_1_cfg) server_1.power_on() network.connect(endpoint_a=server_1.network_interface[1], endpoint_b=switch_2.network_interface[1]) + server_2_cfg = {"type": "server", + "hostname":"server_2", + "ip_address": "10.0.2.3", + "subnet_mask":"255.255.255.0", + "default_gateway":"10.0.2.1", + "start_up_duration": 0} - server_2 = Server( - hostname="server_2", - ip_address="10.0.2.3", - subnet_mask="255.255.255.0", - default_gateway="10.0.2.1", - start_up_duration=0, - ) + + server_2: Server = Server.from_config(config=server_2_cfg) server_2.power_on() network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) @@ -442,18 +445,18 @@ def install_stuff_to_sim(sim: Simulation): assert acl_rule is None # 5.2: Assert the client is correctly configured - c: Computer = [node for node in sim.network.nodes.values() if node.hostname == "client_1"][0] + c: Computer = [node for node in sim.network.nodes.values() if node.config.hostname == "client_1"][0] assert c.software_manager.software.get("WebBrowser") is not None assert c.software_manager.software.get("DNSClient") is not None assert str(c.network_interface[1].ip_address) == "10.0.1.2" # 5.3: Assert that server_1 is correctly configured - s1: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_1"][0] + s1: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_1"][0] assert str(s1.network_interface[1].ip_address) == "10.0.2.2" assert s1.software_manager.software.get("DNSServer") is not None # 5.4: Assert that server_2 is correctly configured - s2: Server = [node for node in sim.network.nodes.values() if node.hostname == "server_2"][0] + s2: Server = [node for node in sim.network.nodes.values() if node.config.hostname == "server_2"][0] assert str(s2.network_interface[1].ip_address) == "10.0.2.3" assert s2.software_manager.software.get("WebServer") is not None diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 97608132..6b0d4359 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -25,7 +25,8 @@ def check_default_rules(acl_obs): def test_firewall_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_cfg = {"type": "firewall", "hostname": "firewall", "opertating_state": NodeOperatingState.ON} + firewall = Firewall.from_config(config=firewall_cfg) firewall_observation = FirewallObservation( where=[], num_rules=7, @@ -116,7 +117,7 @@ def test_firewall_observation(): assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) # connect a switch to the firewall and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) link = net.connect(firewall.network_interface[1], switch.network_interface[1]) assert firewall.network_interface[1].enabled observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 630e29ea..b5cd6134 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -56,12 +56,12 @@ def test_link_observation(): """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) - switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) - computer_1 = Computer( - hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + switch: Switch = Switch.from_config(config={"type":"switch", "hostname":"switch", "num_ports":5, "operating_state":NodeOperatingState.ON}) + computer_1: Computer = Computer.from_config(config={"type": "computer", + "hostname":"computer_1", "ip_address":"10.0.0.1", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) - computer_2 = Computer( - hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + computer_2: Computer = Computer.from_config(config={"type":"computer", + "hostname":"computer_2", "ip_address":"10.0.0.2", "subnet_mask":"255.255.255.0", "start_up_duration":0} ) computer_1.power_on() computer_2.power_on() diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index bd9417ba..2a311853 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -75,7 +75,7 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { @@ -108,7 +108,7 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default @@ -163,7 +163,7 @@ def test_nic_monitored_traffic(simulation): pc2: Computer = simulation.network.get_node_by_hostname("client_2") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic + where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic ) simulation.pre_timestep(0) # apply timestep to whole sim diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 63ca8f6b..09eb3fe4 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -25,7 +25,7 @@ def test_host_observation(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") host_obs = HostObservation( - where=["network", "nodes", pc.hostname], + where=["network", "nodes", pc.config.hostname], num_applications=0, num_files=1, num_folders=1, @@ -56,7 +56,7 @@ def test_host_observation(simulation): observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 4 # shutting down - for i in range(pc.shut_down_duration + 1): + for i in range(pc.config.shut_down_duration + 1): pc.apply_timestep(i) observation_state = host_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index f4bfb193..131af57f 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -16,7 +16,7 @@ from primaite.utils.validation.port import PORT_LOOKUP def test_router_observation(): """Test adding/removing acl rules and enabling/disabling ports.""" net = Network() - router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + router = Router.from_config(config={"type": "router", "hostname":"router", "num_ports":5, "operating_state":NodeOperatingState.ON}) ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] acl = ACLObservation( @@ -89,7 +89,7 @@ def test_router_observation(): assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) # connect a switch to the router and check that only the correct port is updated - switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + switch: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch", "num_ports":1, "operating_state":NodeOperatingState.ON}) link = net.connect(router.network_interface[1], switch.network_interface[1]) assert router.network_interface[1].enabled observed_output = router_observation.observe(router.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 291ee395..7957625a 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,7 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("NTPServer") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "NTPServer"]) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +54,7 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"]) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) diff --git a/tests/integration_tests/game_layer/test_action_mask.py b/tests/integration_tests/game_layer/test_action_mask.py index 75965f16..ebba1119 100644 --- a/tests/integration_tests/game_layer/test_action_mask.py +++ b/tests/integration_tests/game_layer/test_action_mask.py @@ -3,6 +3,7 @@ from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests.conftest import TEST_ASSETS_ROOT CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 5a308cf8..9d9b528c 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -17,6 +17,7 @@ from typing import Tuple import pytest import yaml +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv diff --git a/tests/integration_tests/network/test_airspace_config.py b/tests/integration_tests/network/test_airspace_config.py index e8abc0f2..fd3f6f28 100644 --- a/tests/integration_tests/network/test_airspace_config.py +++ b/tests/integration_tests/network/test_airspace_config.py @@ -2,6 +2,7 @@ import yaml from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter from tests import TEST_ASSETS_ROOT diff --git a/tests/integration_tests/network/test_broadcast.py b/tests/integration_tests/network/test_broadcast.py index ed40334f..5c30d2ac 100644 --- a/tests/integration_tests/network/test_broadcast.py +++ b/tests/integration_tests/network/test_broadcast.py @@ -84,44 +84,47 @@ class BroadcastTestClient(Application, identifier="BroadcastTestClient"): def broadcast_network() -> Network: network = Network() - client_1 = Computer( - hostname="client_1", - ip_address="192.168.1.2", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_1_cfg = {"type": "computer", + "hostname": "client_1", + "ip_address":"192.168.1.2", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} + + client_1: Computer = Computer.from_config(config=client_1_cfg) client_1.power_on() client_1.software_manager.install(BroadcastTestClient) application_1 = client_1.software_manager.software["BroadcastTestClient"] application_1.run() + client_2_cfg = {"type": "computer", + "hostname": "client_2", + "ip_address":"192.168.1.3", + "subnet_mask": "255.255.255.0", + "default_gateway": "192.168.1.1", + "start_up_duration":0} - client_2 = Computer( - hostname="client_2", - ip_address="192.168.1.3", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + client_2: Computer = Computer.from_config(config=client_2_cfg) client_2.power_on() client_2.software_manager.install(BroadcastTestClient) application_2 = client_2.software_manager.software["BroadcastTestClient"] application_2.run() - server_1 = Server( - hostname="server_1", - ip_address="192.168.1.1", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + server_1_cfg = {"type": "server", + "hostname": "server_1", + "ip_address":"192.168.1.1", + "subnet_mask": "255.255.255.0", + "default_gateway":"192.168.1.1", + "start_up_duration": 0} + + server_1 :Server = Server.from_config(config=server_1_cfg) + server_1.power_on() server_1.software_manager.install(BroadcastTestService) service: BroadcastTestService = server_1.software_manager.software["BroadcastService"] service.start() - switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0) + switch_1: Switch = Switch.from_config(config={"type": "switch", "hostname":"switch_1", "num_ports":6, "start_up_duration":0}) switch_1.power_on() network.connect(endpoint_a=client_1.network_interface[1], endpoint_b=switch_1.network_interface[1]) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index f78b3261..85cd369f 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -12,13 +12,10 @@ from primaite.utils.validation.port import PORT_LOOKUP @pytest.fixture(scope="function") def web_browser() -> WebBrowser: - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") @@ -28,13 +25,10 @@ def web_browser() -> WebBrowser: def test_create_web_client(): - computer = Computer( - hostname="web_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - start_up_duration=0, - ) + computer_cfg = {"type": "computer", "hostname": "web_client", "ip_address": "192.168.1.11", "subnet_mask": "255.255.255.0", "default_gateway": "192.168.1.1", "start_up_duration": 0} + + computer: Computer = Computer.from_config(config=computer_cfg) + computer.power_on() # Web Browser should be pre-installed in computer web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")