diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 9070270a..8a944c7f 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -6,6 +6,9 @@ trigger: - bugfix/* - release/* +pr: + autoCancel: true + drafts: false parameters: # https://stackoverflow.com/a/70046417 - name: matrix @@ -85,6 +88,33 @@ stages: primaite setup displayName: 'Perform PrimAITE Setup' + - task: UseDotNet@2 + displayName: 'Install dotnet dependencies' + inputs: + packageType: 'sdk' + version: '2.1.x' + - script: | - pytest -n auto - displayName: 'Run tests' + coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml + coverage xml -o coverage.xml -i + coverage html -d htmlcov -i + displayName: 'Run tests and code coverage' + + - task: PublishTestResults@2 + condition: succeededOrFailed() + inputs: + testRunner: JUnit + testResultsFiles: 'junit/**.xml' + testRunTitle: 'Publish test results' + + - publish: $(System.DefaultWorkingDirectory)/htmlcov/ + # publish the html report - so we can debug the coverage if needed + condition: ${{ item.every_time }} # should only be run once + artifact: coverage_report + + - task: PublishCodeCoverageResults@2 + # publish the code coverage so it can be viewed in the run coverage page + condition: ${{ item.every_time }} # should only be run once + inputs: + codeCoverageTool: Cobertura + summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml' diff --git a/.gitignore b/.gitignore index a6404ac6..892751d9 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports +junit/ htmlcov/ .tox/ .nox/ diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/data_manipulation_bot.rst index 5180974f..e9cfde71 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/data_manipulation_bot.rst @@ -54,7 +54,7 @@ Example ) network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index 1e5a0b6b..cd6b0aa3 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -28,7 +28,7 @@ See :ref:`Node Start up and Shut down` node.software_manager.install(WebServer) - web_server: WebServer = node.software_manager.software["WebServer"] + web_server: WebServer = node.software_manager.software.get("WebServer") assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install node.power_off() diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index c70d4d66..8eed3ba4 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -424,7 +424,7 @@ class NetworkACLAddRuleAction(AbstractAction): elif permission == 2: permission_str = "DENY" else: - _LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.") + _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") if protocol_id == 0: return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index 93fd81b8..767514b4 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -264,7 +264,7 @@ class FolderObservation(AbstractObservation): while len(self.files) > num_files_per_folder: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" - _LOGGER.warn(msg) + _LOGGER.warning(msg) self.default_observation = { "health_status": 0, @@ -438,7 +438,7 @@ class NodeObservation(AbstractObservation): while len(self.services) > num_services_per_node: truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warn(msg) + _LOGGER.warning(msg) # truncate service list self.folders: List[FolderObservation] = folders @@ -448,7 +448,7 @@ class NodeObservation(AbstractObservation): while len(self.folders) > num_folders_per_node: truncated_folder = self.folders.pop() msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" - _LOGGER.warn(msg) + _LOGGER.warning(msg) self.nics: List[NicObservation] = nics while len(self.nics) < num_nics_per_node: @@ -456,7 +456,7 @@ class NodeObservation(AbstractObservation): while len(self.nics) > num_nics_per_node: truncated_nic = self.nics.pop() msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warn(msg) + _LOGGER.warning(msg) self.logon_status: bool = logon_status diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 71945a24..b7a5e9be 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -210,7 +210,7 @@ class WebServer404Penalty(AbstractReward): f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not " "found in reward config." ) - _LOGGER.warn(msg) + _LOGGER.warning(msg) return DummyReward() # TODO: should we error out with incorrect inputs? Probably! node_uuid = game.ref_map_nodes[node_ref] service_uuid = game.ref_map_services[service_ref] @@ -219,7 +219,7 @@ class WebServer404Penalty(AbstractReward): f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" " found in the simulator." ) - _LOGGER.warn(msg) + _LOGGER.warning(msg) return DummyReward() # TODO: consider erroring here as well return cls(node_uuid=node_uuid, service_uuid=service_uuid) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 08779d96..98a7e8db 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -252,6 +252,6 @@ class SimComponent(BaseModel): def parent(self, new_parent: Union["SimComponent", None]) -> None: if self._parent and new_parent: msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}" - _LOGGER.warn(msg) + _LOGGER.warning(msg) raise RuntimeWarning(msg) self._parent = new_parent diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index 1402a474..d9dad06a 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -72,7 +72,7 @@ class Account(SimComponent): "num_group_changes": self.num_group_changes, "username": self.username, "password": self.password, - "account_type": self.account_type.name, + "account_type": self.account_type.value, "enabled": self.enabled, } ) diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index 25a584c4..c2eb0d2d 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -53,7 +53,10 @@ class FileSystem(SimComponent): original_folder_uuids = self._original_state["original_folder_uuids"] for uuid in original_folder_uuids: if uuid in self.deleted_folders: - self.folders[uuid] = self.deleted_folders.pop(uuid) + folder = self.deleted_folders[uuid] + self.deleted_folders.pop(uuid) + self.folders[uuid] = folder + self._folders_by_name[folder.name] = folder # Clear any other deleted folders that aren't original (have been created by agent) self.deleted_folders.clear() @@ -62,7 +65,9 @@ class FileSystem(SimComponent): current_folder_uuids = list(self.folders.keys()) for uuid in current_folder_uuids: if uuid not in original_folder_uuids: + folder = self.folders[uuid] self.folders.pop(uuid) + self._folders_by_name.pop(folder.name) # Now reset all remaining folders for folder in self.folders.values(): diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 8fca4368..fd18e154 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -75,7 +75,10 @@ class Folder(FileSystemItemABC): original_file_uuids = self._original_state["original_file_uuids"] for uuid in original_file_uuids: if uuid in self.deleted_files: - self.files[uuid] = self.deleted_files.pop(uuid) + file = self.deleted_files[uuid] + self.deleted_files.pop(uuid) + self.files[uuid] = file + self._files_by_name[file.name] = file # Clear any other deleted files that aren't original (have been created by agent) self.deleted_files.clear() @@ -84,7 +87,9 @@ class Folder(FileSystemItemABC): current_file_uuids = list(self.files.keys()) for uuid in current_file_uuids: if uuid not in original_file_uuids: + file = self.files[uuid] self.files.pop(uuid) + self._files_by_name.pop(file.name) # Now reset all remaining files for file in self.files.values(): diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index b7bd2e95..4cd9c8d3 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -51,14 +51,22 @@ def client_server_routed() -> Network: # Client 1 client_1 = Computer( - hostname="client_1", ip_address="192.168.2.2", subnet_mask="255.255.255.0", default_gateway="192.168.2.1" + hostname="client_1", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.2.1", + operating_state=NodeOperatingState.ON, ) client_1.power_on() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) # Server 1 server_1 = Server( - hostname="server_1", ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" + hostname="server_1", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, ) server_1.power_on() network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) @@ -139,7 +147,7 @@ def arcd_uc2_network() -> Network: client_1.power_on() network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) client_1.software_manager.install(DataManipulationBot) - db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") db_manipulation_bot.configure( server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE", @@ -157,7 +165,7 @@ def arcd_uc2_network() -> Network: operating_state=NodeOperatingState.ON, ) client_2.power_on() - web_browser = client_2.software_manager.software["WebBrowser"] + web_browser = client_2.software_manager.software.get("WebBrowser") web_browser.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) @@ -241,7 +249,7 @@ def arcd_uc2_network() -> Network: # noqa ] database_server.software_manager.install(DatabaseService) - database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa + database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) database_service._process_sql(ddl, None) # noqa @@ -260,7 +268,7 @@ def arcd_uc2_network() -> Network: web_server.power_on() web_server.software_manager.install(DatabaseClient) - database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") database_client.configure(server_ip_address=IPv4Address("192.168.1.14")) network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) database_client.run() @@ -269,7 +277,7 @@ def arcd_uc2_network() -> Network: web_server.software_manager.install(WebServer) # register the web_server to a domain - dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa + dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa dns_server_service.dns_register("arcd.com", web_server.ip_address) # Backup Server diff --git a/src/primaite/simulator/network/utils.py b/src/primaite/simulator/network/utils.py index 496f5e13..33085bd6 100644 --- a/src/primaite/simulator/network/utils.py +++ b/src/primaite/simulator/network/utils.py @@ -5,6 +5,8 @@ def convert_bytes_to_megabits(B: Union[int, float]) -> float: # noqa - Keep it """ Convert Bytes (file size) to Megabits (data transfer). + Technically Mebibits - but for simplicity sake, we'll call it megabit + :param B: The file size in Bytes. :return: File bits to transfer in Megabits. """ diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 7b63d26e..f57246fc 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -73,7 +73,8 @@ class DatabaseClient(Application): if not self.connected: return self._connect(self.server_ip_address, self.server_password) - return False + # already connected + return True def _connect( self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False @@ -107,7 +108,7 @@ class DatabaseClient(Application): def disconnect(self): """Disconnect from the Database Service.""" - if self.connected and self.operating_state.RUNNING: + if self.connected and self.operating_state is ApplicationOperatingState.RUNNING: software_manager: SoftwareManager = self.software_manager software_manager.send_payload_to_session_manager( payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port @@ -186,6 +187,9 @@ class DatabaseClient(Application): :param session_id: The session id the payload relates to. :return: True. """ + if not self._can_perform_action(): + return False + if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "connect_response": self.connected = payload["response"] == True diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 1531314d..7533f6f3 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -99,7 +99,7 @@ class WebBrowser(Application): return False # get the IP address of the domain name via DNS - dns_client: DNSClient = self.software_manager.software["DNSClient"] + dns_client: DNSClient = self.software_manager.software.get("DNSClient") domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) # if domain does not exist, the request fails diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index bba4e777..61cf1560 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -80,7 +80,7 @@ class DatabaseService(Service): return False software_manager: SoftwareManager = self.software_manager - ftp_client_service: FTPClient = software_manager.software["FTPClient"] + ftp_client_service: FTPClient = software_manager.software.get("FTPClient") # send backup copy of database file to FTP server response = ftp_client_service.send_file( @@ -104,7 +104,7 @@ class DatabaseService(Service): return False software_manager: SoftwareManager = self.software_manager - ftp_client_service: FTPClient = software_manager.software["FTPClient"] + ftp_client_service: FTPClient = software_manager.software.get("FTPClient") # retrieve backup file from backup server response = ftp_client_service.request_file( diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 23d52342..52655fa4 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -8,7 +8,6 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC -from primaite.simulator.system.services.service import ServiceOperatingState _LOGGER = getLogger(__name__) @@ -53,8 +52,7 @@ class FTPClient(FTPServiceABC): :type: session_id: Optional[str] """ # if client service is down, return error - if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.error("FTP Client is not running") + if not self._can_perform_action(): payload.status_code = FTPStatusCode.ERROR return payload @@ -81,8 +79,7 @@ class FTPClient(FTPServiceABC): :type: is_reattempt: Optional[bool] """ # make sure the service is running before attempting - if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.error(f"FTPClient not running for {self.sys_log.hostname}") + if not self._can_perform_action(): return False # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now @@ -282,8 +279,11 @@ class FTPClient(FTPServiceABC): This helps prevent an FTP request loop - FTP client and servers can exist on the same node. """ + if not self._can_perform_action(): + return False + if payload.status_code is None: - self.sys_log.error(f"FTP Server could not be found - Error Code: {payload.status_code.value}") + self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}") return False self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}") diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index 44d0455f..0278b616 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -6,7 +6,6 @@ from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPS from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC -from primaite.simulator.system.services.service import ServiceOperatingState _LOGGER = getLogger(__name__) @@ -58,8 +57,7 @@ class FTPServer(FTPServiceABC): payload.status_code = FTPStatusCode.ERROR # if server service is down, return error - if self.operating_state != ServiceOperatingState.RUNNING: - self.sys_log.error("FTP Server not running") + if not self._can_perform_action(): return payload self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}") @@ -95,6 +93,9 @@ class FTPServer(FTPServiceABC): self.sys_log.error(f"{payload} is not an FTP packet") return False + if not super().receive(payload=payload, session_id=session_id, **kwargs): + return False + """ Ignore ftp payload if status code is defined. @@ -102,9 +103,6 @@ class FTPServer(FTPServiceABC): prevents an FTP request loop - FTP client and servers can exist on the same node. """ - if not super().receive(payload=payload, session_id=session_id, **kwargs): - return False - if payload.status_code is not None: return False diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 6d6cda86..e60b7700 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -109,8 +109,8 @@ class Service(IOSoftware): """ state = super().describe_state() state["operating_state"] = self.operating_state.value - state["health_state_actual"] = self.health_state_actual - state["health_state_visible"] = self.health_state_visible + state["health_state_actual"] = self.health_state_actual.value + state["health_state_visible"] = self.health_state_visible.value return state def stop(self) -> None: diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index e5f3dccc..afd6cb74 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -119,7 +119,7 @@ class WebServer(Service): if path.startswith("users"): # get data from DatabaseServer - db_client: DatabaseClient = self.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient") # get all users if db_client.query("SELECT"): # query succeeded diff --git a/tests/conftest.py b/tests/conftest.py index c0d05455..1ab07dd8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Tuple, Union import pytest import yaml @@ -12,6 +12,11 @@ from primaite.session.session import PrimaiteSession # from primaite.environment.primaite_env import Primaite # from primaite.primaite_session import PrimaiteSession from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.hardware.nodes.switch import Switch from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port @@ -29,7 +34,7 @@ from primaite import PRIMAITE_PATHS # PrimAITE v3 stuff from primaite.simulator.file_system.file_system import FileSystem -from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.hardware.base import Link, Node class TestService(Service): @@ -122,3 +127,110 @@ def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession: monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) config_path = request.param[0] return TempPrimaiteSession.from_config(config_path=config_path) + + +@pytest.fixture(scope="function") +def client_server() -> Tuple[Computer, Server]: + # Create Computer + computer: Computer = Computer( + hostname="test_computer", + ip_address="192.168.0.1", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + + # Create Server + server = Server( + hostname="server", ip_address="192.168.0.2", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON + ) + + # Connect Computer and Server + computer_nic = computer.nics[next(iter(computer.nics))] + server_nic = server.nics[next(iter(server.nics))] + link = Link(endpoint_a=computer_nic, endpoint_b=server_nic) + + # Should be linked + assert link.is_up + + return computer, server + + +@pytest.fixture(scope="function") +def example_network() -> Network: + """ + Create the network used for testing. + + Should only contain the nodes and links. + This would act as the base network and services and applications are installed in the relevant test file, + + -------------- -------------- + | client_1 |----- ----| server_1 | + -------------- | -------------- -------------- -------------- | -------------- + ------| switch_1 |------| router_1 |------| switch_2 |------ + -------------- | -------------- -------------- -------------- | -------------- + | client_2 |---- ----| server_2 | + -------------- -------------- + """ + network = Network() + + # Router 1 + router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON) + router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0") + router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0") + + # Switch 1 + switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON) + network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8]) + router_1.enable_port(1) + + # Switch 2 + switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON) + network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8]) + router_1.enable_port(2) + + # Client 1 + client_1 = Computer( + hostname="client_1", + ip_address="192.168.10.21", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + operating_state=NodeOperatingState.ON, + ) + network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) + + # Client 2 + client_2 = Computer( + hostname="client_2", + ip_address="192.168.10.22", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + operating_state=NodeOperatingState.ON, + ) + network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) + + # Domain Controller + server_1 = Server( + hostname="server_1", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + + network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) + + # Database Server + server_2 = Server( + hostname="server_2", + ip_address="192.168.1.14", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + network.connect(endpoint_b=server_2.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) + + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) + + return network diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 3907ff50..91cf5c1e 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -2,6 +2,7 @@ import tempfile from pathlib import Path +import pytest import yaml from stable_baselines3 import PPO @@ -10,6 +11,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv +# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" with open(example_config_path(), "r") as f: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index f2b6aa3f..7785e4ae 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -11,6 +11,7 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml" MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" +# @pytest.mark.skip(reason="no way of currently testing this") class TestPrimaiteSession: @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_creating_session(self, temp_primaite_session): diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 81bbfc96..0dc2c031 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -8,13 +8,13 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor def test_data_manipulation(uc2_network): """Tests the UC2 data manipulation scenario end-to-end. Is a work in progress.""" client_1: Computer = uc2_network.get_node_by_hostname("client_1") - db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] + db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") database_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = database_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") db_service.backup_database() diff --git a/tests/integration_tests/network/test_link_connection.py b/tests/integration_tests/network/test_link_connection.py index 0ddf54df..c6aeac24 100644 --- a/tests/integration_tests/network/test_link_connection.py +++ b/tests/integration_tests/network/test_link_connection.py @@ -16,3 +16,9 @@ def test_link_up(): assert nic_a.enabled assert nic_b.enabled assert link.is_up + + +def test_ping_between_computer_and_server(client_server): + computer, server = client_server + + assert computer.ping(target_ip_address=server.nics[next(iter(server.nics))].ip_address) diff --git a/tests/integration_tests/network/test_network_creation.py b/tests/integration_tests/network/test_network_creation.py index 91218068..0af44dbb 100644 --- a/tests/integration_tests/network/test_network_creation.py +++ b/tests/integration_tests/network/test_network_creation.py @@ -2,6 +2,28 @@ import pytest from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import NIC, Node +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.networks import client_server_routed + + +def test_network(example_network): + network: Network = example_network + client_1: Computer = network.get_node_by_hostname("client_1") + client_2: Computer = network.get_node_by_hostname("client_2") + server_1: Server = network.get_node_by_hostname("server_1") + server_2: Server = network.get_node_by_hostname("server_2") + + assert client_1.ping(client_2.ethernet_port[1].ip_address) + assert client_2.ping(client_1.ethernet_port[1].ip_address) + + assert server_1.ping(server_2.ethernet_port[1].ip_address) + assert server_2.ping(server_1.ethernet_port[1].ip_address) + + assert client_1.ping(server_1.ethernet_port[1].ip_address) + assert client_2.ping(server_1.ethernet_port[1].ip_address) + assert client_1.ping(server_2.ethernet_port[1].ip_address) + assert client_2.ping(server_2.ethernet_port[1].ip_address) def test_adding_removing_nodes(): diff --git a/tests/integration_tests/system/test_application_on_node.py b/tests/integration_tests/system/test_application_on_node.py index cce586da..46be5e55 100644 --- a/tests/integration_tests/system/test_application_on_node.py +++ b/tests/integration_tests/system/test_application_on_node.py @@ -18,7 +18,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]: ) computer.software_manager.install(application_class) - app = computer.software_manager.software["TestApplication"] + app = computer.software_manager.software.get("TestApplication") app.run() return app, computer @@ -35,7 +35,7 @@ def test_service_on_offline_node(application_class): ) computer.software_manager.install(application_class) - app: Application = computer.software_manager.software["TestApplication"] + app: Application = computer.software_manager.software.get("TestApplication") computer.power_off() diff --git a/tests/integration_tests/system/test_database_on_node.py b/tests/integration_tests/system/test_database_on_node.py index ef2b2956..98c8c87b 100644 --- a/tests/integration_tests/system/test_database_on_node.py +++ b/tests/integration_tests/system/test_database_on_node.py @@ -10,10 +10,10 @@ from primaite.simulator.system.services.service import ServiceOperatingState def test_database_client_server_connection(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") assert len(db_service.connections) == 1 @@ -23,10 +23,10 @@ def test_database_client_server_connection(uc2_network): def test_database_client_server_correct_password(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") db_client.disconnect() @@ -40,10 +40,10 @@ def test_database_client_server_correct_password(uc2_network): def test_database_client_server_incorrect_password(uc2_network): web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") db_client.disconnect() db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321") @@ -56,7 +56,7 @@ def test_database_client_server_incorrect_password(uc2_network): def test_database_client_query(uc2_network): """Tests DB query across the network returns HTTP status 200 and date.""" web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") assert db_client.connected @@ -66,13 +66,13 @@ def test_database_client_query(uc2_network): def test_create_database_backup(uc2_network): """Run the backup_database method and check if the FTP server has the relevant file.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") # back up should be created assert db_service.backup_database() is True backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"] + ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer") # backup file should exist in the backup server assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None @@ -81,7 +81,7 @@ def test_create_database_backup(uc2_network): def test_restore_backup(uc2_network): """Run the restore_backup method and check if the backup is properly restored.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") # create a back up assert db_service.backup_database() is True @@ -100,13 +100,13 @@ def test_restore_backup(uc2_network): def test_database_client_cannot_query_offline_database_server(uc2_network): """Tests DB query across the network returns HTTP status 404 when db server is offline.""" db_server: Server = uc2_network.get_node_by_hostname("database_server") - db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") assert db_server.operating_state is NodeOperatingState.ON assert db_service.operating_state is ServiceOperatingState.RUNNING web_server: Server = uc2_network.get_node_by_hostname("web_server") - db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") assert db_client.connected assert db_client.query("SELECT") is True diff --git a/tests/integration_tests/system/test_dns_client_server.py b/tests/integration_tests/system/test_dns_client_server.py index 81a223ef..a54bf23f 100644 --- a/tests/integration_tests/system/test_dns_client_server.py +++ b/tests/integration_tests/system/test_dns_client_server.py @@ -1,3 +1,8 @@ +from ipaddress import IPv4Address +from typing import Tuple + +import pytest + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server @@ -6,12 +11,31 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState -def test_dns_client_server(uc2_network): - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller") +@pytest.fixture(scope="function") +def dns_client_and_dns_server(client_server) -> Tuple[DNSClient, Computer, DNSServer, Server]: + computer, server = client_server - dns_client: DNSClient = client_1.software_manager.software["DNSClient"] - dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] + # Install DNS Client on computer + computer.software_manager.install(DNSClient) + dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + dns_client.start() + # set server as DNS Server + dns_client.dns_server = IPv4Address(server.nics.get(next(iter(server.nics))).ip_address) + + # Install DNS Server on server + server.software_manager.install(DNSServer) + dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server.start() + # register arcd.com as a domain + dns_server.dns_register( + domain_name="arcd.com", domain_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address) + ) + + return dns_client, computer, dns_server, server + + +def test_dns_client_server(dns_client_and_dns_server): + dns_client, computer, dns_server, server = dns_client_and_dns_server assert dns_client.operating_state == ServiceOperatingState.RUNNING assert dns_server.operating_state == ServiceOperatingState.RUNNING @@ -29,12 +53,8 @@ def test_dns_client_server(uc2_network): assert len(dns_client.dns_cache) == 1 -def test_dns_client_requests_offline_dns_server(uc2_network): - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller") - - dns_client: DNSClient = client_1.software_manager.software["DNSClient"] - dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] +def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server): + dns_client, computer, dns_server, server = dns_client_and_dns_server assert dns_client.operating_state == ServiceOperatingState.RUNNING assert dns_server.operating_state == ServiceOperatingState.RUNNING @@ -48,12 +68,12 @@ def test_dns_client_requests_offline_dns_server(uc2_network): assert len(dns_client.dns_cache) == 1 dns_client.dns_cache = {} - domain_controller.power_off() + server.power_off() - for i in range(domain_controller.shut_down_duration + 1): - uc2_network.apply_timestep(timestep=i) + for i in range(server.shut_down_duration + 1): + server.apply_timestep(timestep=i) - assert domain_controller.operating_state == NodeOperatingState.OFF + assert server.operating_state == NodeOperatingState.OFF assert dns_server.operating_state == ServiceOperatingState.STOPPED # this time it should not cache because dns server is not online diff --git a/tests/integration_tests/system/test_ftp_client_server.py b/tests/integration_tests/system/test_ftp_client_server.py index b2cdbc06..1a6a8f41 100644 --- a/tests/integration_tests/system/test_ftp_client_server.py +++ b/tests/integration_tests/system/test_ftp_client_server.py @@ -1,4 +1,7 @@ from ipaddress import IPv4Address +from typing import Tuple + +import pytest from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server @@ -7,18 +10,31 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.service import ServiceOperatingState -def test_ftp_client_store_file_in_server(uc2_network): +@pytest.fixture(scope="function") +def ftp_client_and_ftp_server(client_server) -> Tuple[FTPClient, Computer, FTPServer, Server]: + computer, server = client_server + + # Install FTP Client service on computer + computer.software_manager.install(FTPClient) + ftp_client: FTPClient = computer.software_manager.software.get("FTPClient") + ftp_client.start() + + # Install FTP Server service on server + server.software_manager.install(FTPServer) + ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server.start() + + return ftp_client, computer, ftp_server, server + + +def test_ftp_client_store_file_in_server(ftp_client_and_ftp_server): """ Test checks to see if the client is able to store files in the backup server. """ - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - - ftp_client: FTPClient = client_1.software_manager.software["FTPClient"] - ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"] + ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server assert ftp_client.operating_state == ServiceOperatingState.RUNNING - assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING + assert ftp_server.operating_state == ServiceOperatingState.RUNNING # create file on ftp client ftp_client.file_system.create_file(file_name="test_file.txt") @@ -28,61 +44,53 @@ def test_ftp_client_store_file_in_server(uc2_network): src_file_name="test_file.txt", dest_folder_name="client_1_backup", dest_file_name="test_file.txt", - dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, + dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address, ) - assert ftp_server_service.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt") + assert ftp_server.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt") -def test_ftp_client_retrieve_file_from_server(uc2_network): +def test_ftp_client_retrieve_file_from_server(ftp_client_and_ftp_server): """ Test checks to see if the client is able to retrieve files from the backup server. """ - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - - ftp_client: FTPClient = client_1.software_manager.software["FTPClient"] - ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"] + ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server assert ftp_client.operating_state == ServiceOperatingState.RUNNING - assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING + assert ftp_server.operating_state == ServiceOperatingState.RUNNING # create file on ftp server - ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share") + ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share") assert ftp_client.request_file( src_folder_name="file_share", src_file_name="test_file.txt", dest_folder_name="downloads", dest_file_name="test_file.txt", - dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, + dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address, ) # client should have retrieved the file assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt") -def test_ftp_client_tries_to_connect_to_offline_server(uc2_network): +def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server): """Test checks to make sure that the client can't do anything when the server is offline.""" - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - backup_server: Server = uc2_network.get_node_by_hostname("backup_server") - - ftp_client: FTPClient = client_1.software_manager.software["FTPClient"] - ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"] + ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server assert ftp_client.operating_state == ServiceOperatingState.RUNNING - assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING + assert ftp_server.operating_state == ServiceOperatingState.RUNNING # create file on ftp server - ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share") + ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share") - backup_server.power_off() + server.power_off() - for i in range(backup_server.shut_down_duration + 1): - uc2_network.apply_timestep(timestep=i) + for i in range(server.shut_down_duration + 1): + server.apply_timestep(timestep=i) assert ftp_client.operating_state == ServiceOperatingState.RUNNING - assert ftp_server_service.operating_state == ServiceOperatingState.STOPPED + assert ftp_server.operating_state == ServiceOperatingState.STOPPED assert ( ftp_client.request_file( @@ -90,7 +98,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(uc2_network): src_file_name="test_file.txt", dest_folder_name="downloads", dest_file_name="test_file.txt", - dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, + dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address, ) is False ) diff --git a/tests/integration_tests/system/test_service_on_node.py b/tests/integration_tests/system/test_service_on_node.py index 9480c358..aab1e4da 100644 --- a/tests/integration_tests/system/test_service_on_node.py +++ b/tests/integration_tests/system/test_service_on_node.py @@ -17,7 +17,7 @@ def populated_node( ) server.software_manager.install(service_class) - service = server.software_manager.software["TestService"] + service = server.software_manager.software.get("TestService") service.start() return server, service @@ -34,7 +34,7 @@ def test_service_on_offline_node(service_class): ) computer.software_manager.install(service_class) - service: Service = computer.software_manager.software["TestService"] + service: Service = computer.software_manager.software.get("TestService") computer.power_off() diff --git a/tests/integration_tests/system/test_web_client_server.py b/tests/integration_tests/system/test_web_client_server.py index 3ee1e3ed..b3d2e891 100644 --- a/tests/integration_tests/system/test_web_client_server.py +++ b/tests/integration_tests/system/test_web_client_server.py @@ -1,104 +1,118 @@ +from typing import Tuple + +import pytest + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.http import HttpStatusCode from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.web_server.web_server import WebServer -def test_web_page_home_page(uc2_network): - """Test to see if the browser is able to open the main page of the web server.""" - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() - web_client.target_url = "http://arcd.com/" - assert web_client.operating_state == ApplicationOperatingState.RUNNING +@pytest.fixture(scope="function") +def web_client_and_web_server(client_server) -> Tuple[WebBrowser, Computer, WebServer, Server]: + computer, server = client_server - assert web_client.get_webpage() is True + # Install Web Browser on computer + computer.software_manager.install(WebBrowser) + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser.run() - # latest reponse should have status code 200 - assert web_client.latest_response is not None - assert web_client.latest_response.status_code == HttpStatusCode.OK + # Install DNS Client service on computer + computer.software_manager.install(DNSClient) + dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + # set dns server + dns_client.dns_server = server.nics[next(iter(server.nics))].ip_address + + # Install Web Server service on server + server.software_manager.install(WebServer) + web_server_service: WebServer = server.software_manager.software.get("WebServer") + web_server_service.start() + + # Install DNS Server service on server + server.software_manager.install(DNSServer) + dns_server: DNSServer = server.software_manager.software.get("DNSServer") + # register arcd.com to DNS + dns_server.dns_register(domain_name="arcd.com", domain_ip_address=server.nics[next(iter(server.nics))].ip_address) + + return web_browser, computer, web_server_service, server -def test_web_page_get_users_page_request_with_domain_name(uc2_network): +def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_server): """Test to see if the client can handle requests with domain names""" - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() - assert web_client.operating_state == ApplicationOperatingState.RUNNING - web_client.target_url = "http://arcd.com/users/" + web_browser_app, computer, web_server_service, server = web_client_and_web_server - assert web_client.get_webpage() is True + web_server_ip = server.nics.get(next(iter(server.nics))).ip_address + web_browser_app.target_url = f"http://arcd.com/" + assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING + + assert web_browser_app.get_webpage() is True # latest response should have status code 200 - assert web_client.latest_response is not None - assert web_client.latest_response.status_code == HttpStatusCode.OK + assert web_browser_app.latest_response is not None + assert web_browser_app.latest_response.status_code == HttpStatusCode.OK -def test_web_page_get_users_page_request_with_ip_address(uc2_network): +def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_server): """Test to see if the client can handle requests that use ip_address.""" - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() + web_browser_app, computer, web_server_service, server = web_client_and_web_server - web_server: Server = uc2_network.get_node_by_hostname("web_server") + web_server_ip = server.nics.get(next(iter(server.nics))).ip_address + web_browser_app.target_url = f"http://{web_server_ip}/" + assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING - web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address - web_client.target_url = f"http://{web_server_ip}/users/" - assert web_client.operating_state == ApplicationOperatingState.RUNNING - - assert web_client.get_webpage() is True + assert web_browser_app.get_webpage() is True # latest response should have status code 200 - assert web_client.latest_response is not None - assert web_client.latest_response.status_code == HttpStatusCode.OK + assert web_browser_app.latest_response is not None + assert web_browser_app.latest_response.status_code == HttpStatusCode.OK -def test_web_page_request_from_shut_down_server(uc2_network): +def test_web_page_request_from_shut_down_server(web_client_and_web_server): """Test to see that the web server does not respond when the server is off.""" - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() + web_browser_app, computer, web_server_service, server = web_client_and_web_server - web_server: Server = uc2_network.get_node_by_hostname("web_server") + web_server_ip = server.nics.get(next(iter(server.nics))).ip_address + web_browser_app.target_url = f"http://arcd.com/" + assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING - assert web_client.operating_state == ApplicationOperatingState.RUNNING - - assert web_client.get_webpage("http://arcd.com/users/") is True + assert web_browser_app.get_webpage() is True # latest response should have status code 200 - assert web_client.latest_response.status_code == HttpStatusCode.OK + assert web_browser_app.latest_response is not None + assert web_browser_app.latest_response.status_code == HttpStatusCode.OK - web_server.power_off() + server.power_off() - for i in range(web_server.shut_down_duration + 1): - uc2_network.apply_timestep(timestep=i) + server.power_off() + + for i in range(server.shut_down_duration + 1): + server.apply_timestep(timestep=i) # node should be off - assert web_server.operating_state is NodeOperatingState.OFF + assert server.operating_state is NodeOperatingState.OFF - assert web_client.get_webpage("http://arcd.com/users/") is False - assert web_client.latest_response.status_code == HttpStatusCode.NOT_FOUND + assert web_browser_app.get_webpage() is False + assert web_browser_app.latest_response.status_code == HttpStatusCode.NOT_FOUND -def test_web_page_request_from_closed_web_browser(uc2_network): - client_1: Computer = uc2_network.get_node_by_hostname("client_1") - web_client: WebBrowser = client_1.software_manager.software["WebBrowser"] - web_client.run() +def test_web_page_request_from_closed_web_browser(web_client_and_web_server): + web_browser_app, computer, web_server_service, server = web_client_and_web_server - web_server: Server = uc2_network.get_node_by_hostname("web_server") - - assert web_client.operating_state == ApplicationOperatingState.RUNNING - - assert web_client.get_webpage("http://arcd.com/users/") is True + assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING + web_browser_app.target_url = f"http://arcd.com/" + assert web_browser_app.get_webpage() is True # latest response should have status code 200 - assert web_client.latest_response.status_code == HttpStatusCode.OK + assert web_browser_app.latest_response.status_code == HttpStatusCode.OK - web_client.close() + web_browser_app.close() # node should be off - assert web_client.operating_state is ApplicationOperatingState.CLOSED + assert web_browser_app.operating_state is ApplicationOperatingState.CLOSED - assert web_client.get_webpage("http://arcd.com/users/") is False + assert web_browser_app.get_webpage() is False diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py new file mode 100644 index 00000000..a4ef3d52 --- /dev/null +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -0,0 +1,108 @@ +from ipaddress import IPv4Address +from typing import Tuple + +import pytest + +from primaite.simulator.network.hardware.base import Link +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.hardware.nodes.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.web_server.web_server import WebServer + + +@pytest.fixture(scope="function") +def web_client_web_server_database(example_network) -> Tuple[Computer, Server, Server]: + # add rules to network router + router_1: Router = example_network.get_node_by_hostname("router_1") + router_1.acl.add_rule( + action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0 + ) + + # Allow DNS requests + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) + + # Allow FTP requests + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2) + + # Open port 80 for web server + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + + # Create Computer + computer: Computer = example_network.get_node_by_hostname("client_1") + + # Create Web Server + web_server: Server = example_network.get_node_by_hostname("server_1") + + # Create Database Server + db_server = example_network.get_node_by_hostname("server_2") + + # Get the NICs + computer_nic = computer.nics[next(iter(computer.nics))] + server_nic = web_server.nics[next(iter(web_server.nics))] + db_server_nic = db_server.nics[next(iter(db_server.nics))] + + # Connect Computer and Server + link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic) + # Should be linked + assert link_computer_server.is_up + + # Connect database server and web server + link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic) + # Should be linked + assert link_computer_server.is_up + assert link_server_db.is_up + + # Install DatabaseService on db server + db_server.software_manager.install(DatabaseService) + db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService") + db_service.start() + + # Install Web Browser on computer + computer.software_manager.install(WebBrowser) + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser.target_url = "http://arcd.com/users/" + web_browser.run() + + # Install DNS Client service on computer + computer.software_manager.install(DNSClient) + dns_client: DNSClient = computer.software_manager.software.get("DNSClient") + # set dns server + dns_client.dns_server = web_server.nics[next(iter(web_server.nics))].ip_address + + # Install Web Server service on web server + web_server.software_manager.install(WebServer) + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") + web_server_service.start() + + # Install DNS Server service on web server + web_server.software_manager.install(DNSServer) + dns_server: DNSServer = web_server.software_manager.software.get("DNSServer") + # register arcd.com to DNS + dns_server.dns_register( + domain_name="arcd.com", domain_ip_address=web_server.nics[next(iter(web_server.nics))].ip_address + ) + + # Install DatabaseClient service on web server + web_server.software_manager.install(DatabaseClient) + db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") + db_client.server_ip_address = IPv4Address(db_server_nic.ip_address) # set IP address of Database Server + db_client.run() + assert dns_client.check_domain_exists("arcd.com") + assert db_client.connect() + + return computer, web_server, db_server + + +def test_web_client_requests_users(web_client_web_server_database): + computer, web_server, db_server = web_client_web_server_database + + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + + assert web_browser.get_webpage() diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 96c34996..01ad3871 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -1,18 +1,140 @@ """Test the account module of the simulator.""" +import pytest + from primaite.simulator.domain.account import Account, AccountType -def test_account_serialise(): +@pytest.fixture(scope="function") +def account() -> Account: + acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) + acct.set_original_state() + return acct + + +def test_original_state(account): + """Test the original state - see if it resets properly""" + account.log_on() + account.log_off() + account.disable() + + state = account.describe_state() + assert state["num_logons"] is 1 + assert state["num_logoffs"] is 1 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is False + + account.reset_component_for_episode(episode=1) + state = account.describe_state() + assert state["num_logons"] is 0 + assert state["num_logoffs"] is 0 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is True + + account.log_on() + account.log_off() + account.disable() + account.set_original_state() + + account.log_on() + state = account.describe_state() + assert state["num_logons"] is 2 + + account.reset_component_for_episode(episode=2) + state = account.describe_state() + assert state["num_logons"] is 1 + assert state["num_logoffs"] is 1 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is False + + +def test_enable(account): + """Should enable the account.""" + account.enabled = False + account.enable() + assert account.enabled is True + + +def test_disable(account): + """Should disable the account.""" + account.enabled = True + account.disable() + assert account.enabled is False + + +def test_log_on_increments(account): + """Should increase the log on value by 1.""" + account.num_logons = 0 + account.log_on() + assert account.num_logons is 1 + + +def test_log_off_increments(account): + """Should increase the log on value by 1.""" + account.num_logoffs = 0 + account.log_off() + assert account.num_logoffs is 1 + + +def test_account_serialise(account): """Test that an account can be serialised. If pydantic throws error then this test fails.""" - acct = Account(username="Jake", password="JakePass1!", account_type=AccountType.USER) - serialised = acct.model_dump_json() + serialised = account.model_dump_json() print(serialised) -def test_account_deserialise(): +def test_account_deserialise(account): """Test that an account can be deserialised. The test fails if pydantic throws an error.""" acct_json = ( '{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,' - '"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}' + '"username":"Jake","password":"totally_hashed_password","account_type":2,"status":2,"request_manager":null}' ) - acct = Account.model_validate_json(acct_json) + assert Account.model_validate_json(acct_json) + + +def test_describe_state(account): + state = account.describe_state() + assert state["num_logons"] is 0 + assert state["num_logoffs"] is 0 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is True + + account.log_on() + state = account.describe_state() + assert state["num_logons"] is 1 + assert state["num_logoffs"] is 0 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is True + + account.log_off() + state = account.describe_state() + assert state["num_logons"] is 1 + assert state["num_logoffs"] is 1 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is True + + account.disable() + state = account.describe_state() + assert state["num_logons"] is 1 + assert state["num_logoffs"] is 1 + assert state["num_group_changes"] is 0 + assert state["username"] is "Jake" + assert state["password"] is "totally_hashed_password" + assert state["account_type"] is AccountType.USER.value + assert state["enabled"] is False diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 4defc80c..9366d173 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -185,6 +185,38 @@ def test_get_file(file_system): file_system.show(full=True) +def test_reset_file_system(file_system): + # file and folder that existed originally + file_system.create_file(file_name="test_file.zip") + file_system.create_folder(folder_name="test_folder") + file_system.set_original_state() + + # create a new file + file_system.create_file(file_name="new_file.txt") + + # create a new folder + file_system.create_folder(folder_name="new_folder") + + # delete the file that existed originally + file_system.delete_file(folder_name="root", file_name="test_file.zip") + assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None + + # delete the folder that existed originally + file_system.delete_folder(folder_name="test_folder") + assert file_system.get_folder(folder_name="test_folder") is None + + # reset + file_system.reset_component_for_episode(episode=1) + + # deleted original file and folder should be back + assert file_system.get_file(folder_name="root", file_name="test_file.zip") + assert file_system.get_folder(folder_name="test_folder") + + # new file and folder should be removed + assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None + assert file_system.get_folder(folder_name="new_folder") is None + + @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py new file mode 100644 index 00000000..d2d0e52c --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/_hardware/nodes/test_switch.py @@ -0,0 +1,17 @@ +import pytest + +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.switch import Switch + + +@pytest.fixture(scope="function") +def switch() -> Switch: + switch: Switch = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON) + switch.show() + return switch + + +def test_describe_state(switch): + state = switch.describe_state() + assert len(state.get("ports")) is 8 + assert state.get("num_ports") is 8 diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 66bd59a9..021d6777 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -3,6 +3,66 @@ import json import pytest from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService + + +@pytest.fixture(scope="function") +def network(example_network) -> Network: + assert len(example_network.routers) is 1 + assert len(example_network.switches) is 2 + assert len(example_network.computers) is 2 + assert len(example_network.servers) is 2 + + example_network.set_original_state() + example_network.show() + + return example_network + + +def test_describe_state(network): + """Test that describe state works.""" + state = network.describe_state() + + assert len(state["nodes"]) is 7 + assert len(state["links"]) is 6 + + +def test_reset_network(network): + """ + Test that the network is properly reset. + + TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, + etc are also removed/reinstalled + + """ + state_before = network.describe_state() + + client_1: Computer = network.get_node_by_hostname("client_1") + server_1: Computer = network.get_node_by_hostname("server_1") + + assert client_1.operating_state is NodeOperatingState.ON + assert server_1.operating_state is NodeOperatingState.ON + + client_1.power_off() + assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN + + server_1.power_off() + assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN + + assert network.describe_state() is not state_before + + network.reset_component_for_episode(episode=1) + + assert client_1.operating_state is NodeOperatingState.ON + assert server_1.operating_state is NodeOperatingState.ON + + assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps( + state_before, sort_keys=True, indent=2 + ) def test_creating_container(): @@ -10,11 +70,46 @@ def test_creating_container(): net = Network() assert net.nodes == {} assert net.links == {} + net.show() -@pytest.mark.skip(reason="Skipping until we tackle serialisation") -def test_describe_state(): - """Check that we can describe network state without raising errors, and that the result is JSON serialisable.""" - net = Network() - state = net.describe_state() - json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable +def test_apply_timestep_to_nodes(network): + """Calling apply_timestep on the network should apply to the nodes within it.""" + client_1: Computer = network.get_node_by_hostname("client_1") + assert client_1.operating_state is NodeOperatingState.ON + + client_1.power_off() + + for i in range(client_1.shut_down_duration + 1): + network.apply_timestep(timestep=i) + + assert client_1.operating_state is NodeOperatingState.OFF + + +def test_removing_node_that_does_not_exist(network): + """Node that does not exist on network should not affect existing nodes.""" + assert len(network.nodes) is 7 + + network.remove_node(Node(hostname="new_node")) + assert len(network.nodes) is 7 + + +def test_remove_node(network): + """Remove node should remove the correct node.""" + assert len(network.nodes) is 7 + + client_1: Computer = network.get_node_by_hostname("client_1") + network.remove_node(client_1) + + assert network.get_node_by_hostname("client_1") is None + assert len(network.nodes) is 6 + + +def test_remove_link(network): + """Remove link should remove the correct link.""" + assert len(network.links) is 6 + link: Link = network.links.get(next(iter(network.links))) + + network.remove_link(link) + assert len(network.links) is 5 + assert network.links.get(link.uuid) is None diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_utils.py b/tests/unit_tests/_primaite/_simulator/_network/test_utils.py new file mode 100644 index 00000000..a0c1da45 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_network/test_utils.py @@ -0,0 +1,11 @@ +from primaite.simulator.network.utils import convert_bytes_to_megabits, convert_megabits_to_bytes + + +def test_convert_bytes_to_megabits(): + assert round(convert_bytes_to_megabits(B=131072), 5) == float(1) + assert round(convert_bytes_to_megabits(B=69420), 5) == float(0.52963) + + +def test_convert_megabits_to_bytes(): + assert round(convert_megabits_to_bytes(Mbits=1), 5) == float(131072) + assert round(convert_megabits_to_bytes(Mbits=float(0.52963)), 5) == float(69419.66336) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py new file mode 100644 index 00000000..59d44561 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_database_client.py @@ -0,0 +1,122 @@ +from ipaddress import IPv4Address +from typing import Tuple, Union + +import pytest + +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.system.applications.application import ApplicationOperatingState +from primaite.simulator.system.applications.database_client import DatabaseClient + + +@pytest.fixture(scope="function") +def database_client_on_computer() -> Tuple[DatabaseClient, Computer]: + computer = Computer( + hostname="db_node", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON + ) + computer.software_manager.install(DatabaseClient) + + database_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient") + database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) + database_client.run() + return database_client, computer + + +def test_creation(database_client_on_computer): + database_client, computer = database_client_on_computer + database_client.describe_state() + + +def test_connect_when_client_is_closed(database_client_on_computer): + """Database client should not connect when it is not running.""" + database_client, computer = database_client_on_computer + + database_client.close() + assert database_client.operating_state is ApplicationOperatingState.CLOSED + + assert database_client.connect() is False + + +def test_connect_to_database_fails_on_reattempt(database_client_on_computer): + """Database client should return False when the attempt to connect fails.""" + database_client, computer = database_client_on_computer + + database_client.connected = False + assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False + + +def test_disconnect_when_client_is_closed(database_client_on_computer): + """Database client disconnect should not do anything when it is not running.""" + database_client, computer = database_client_on_computer + + database_client.connected = True + assert database_client.server_ip_address is not None + + database_client.close() + assert database_client.operating_state is ApplicationOperatingState.CLOSED + + database_client.disconnect() + + assert database_client.connected is True + assert database_client.server_ip_address is not None + + +def test_disconnect(database_client_on_computer): + """Database client should set connected to False and remove the database server ip address.""" + database_client, computer = database_client_on_computer + + database_client.connected = True + + assert database_client.operating_state is ApplicationOperatingState.RUNNING + assert database_client.server_ip_address is not None + + database_client.disconnect() + + assert database_client.connected is False + assert database_client.server_ip_address is None + + +def test_query_when_client_is_closed(database_client_on_computer): + """Database client should return False when it is not running.""" + database_client, computer = database_client_on_computer + + database_client.close() + assert database_client.operating_state is ApplicationOperatingState.CLOSED + + assert database_client.query(sql="test") is False + + +def test_query_failed_reattempt(database_client_on_computer): + """Database client query should return False if the reattempt fails.""" + database_client, computer = database_client_on_computer + + def return_false(): + return False + + database_client.connect = return_false + + database_client.connected = False + assert database_client.query(sql="test", is_reattempt=True) is False + + +def test_query_fail_to_connect(database_client_on_computer): + """Database client query should return False if the connect attempt fails.""" + database_client, computer = database_client_on_computer + + def return_false(): + return False + + database_client.connect = return_false + database_client.connected = False + + assert database_client.query(sql="test") is False + + +def test_client_receives_response_when_closed(database_client_on_computer): + """Database client receive should return False when it is closed.""" + database_client, computer = database_client_on_computer + + database_client.close() + assert database_client.operating_state is ApplicationOperatingState.CLOSED + + database_client.receive(payload={}, session_id="") diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py index b2724369..dc8f7419 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/test_web_browser.py @@ -1,39 +1,66 @@ +from typing import Tuple + import pytest +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser @pytest.fixture(scope="function") -def web_client() -> Computer: - node = Computer( - hostname="web_client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" +def web_browser() -> WebBrowser: + computer = Computer( + hostname="web_client", + ip_address="192.168.1.11", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, ) - return node + # Web Browser should be pre-installed in computer + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser.run() + assert web_browser.operating_state is ApplicationOperatingState.RUNNING + return web_browser -def test_create_web_client(web_client): - assert web_client is not None - web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"] +def test_create_web_client(): + computer = Computer( + hostname="web_client", + ip_address="192.168.1.11", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + # Web Browser should be pre-installed in computer + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") assert web_browser.name is "WebBrowser" assert web_browser.port is Port.HTTP assert web_browser.protocol is IPProtocol.TCP -def test_receive_invalid_payload(web_client): - web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"] - +def test_receive_invalid_payload(web_browser): assert web_browser.receive(payload={}) is False -def test_receive_payload(web_client): +def test_receive_payload(web_browser): payload = HttpResponsePacket(status_code=HttpStatusCode.OK) - web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"] assert web_browser.latest_response is None web_browser.receive(payload=payload) assert web_browser.latest_response is not None + + +def test_invalid_target_url(web_browser): + # none value target url + web_browser.target_url = None + assert web_browser.get_webpage() is False + + +def test_non_existent_target_url(web_browser): + web_browser.target_url = "http://192.168.255.255" + assert web_browser.get_webpage() is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py index 3b1e4aa4..2c4826bf 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/_red_services/test_data_manipulation_bot.py @@ -19,11 +19,11 @@ def dm_client() -> Node: @pytest.fixture def dm_bot(dm_client) -> DataManipulationBot: - return dm_client.software_manager.software["DataManipulationBot"] + return dm_client.software_manager.software.get("DataManipulationBot") def test_create_dm_bot(dm_client): - data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"] + data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.port == Port.POSTGRES_SERVER diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py index 7662fbff..4d96b584 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_database.py @@ -8,7 +8,7 @@ from primaite.simulator.system.services.database.database_service import Databas def database_server() -> Node: node = Node(hostname="db_node") node.software_manager.install(DatabaseService) - node.software_manager.software["DatabaseService"].start() + node.software_manager.software.get("DatabaseService").start() return node diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py similarity index 63% rename from tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py rename to tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py index 2b4082d9..2bcb512d 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_client.py @@ -5,28 +5,13 @@ import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer -from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import ServiceOperatingState -@pytest.fixture(scope="function") -def dns_server() -> Node: - node = Server( - hostname="dns_server", - ip_address="192.168.1.10", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - operating_state=NodeOperatingState.ON, - ) - node.software_manager.install(software_class=DNSServer) - return node - - @pytest.fixture(scope="function") def dns_client() -> Node: node = Computer( @@ -39,24 +24,16 @@ def dns_client() -> Node: return node -def test_create_dns_server(dns_server): - assert dns_server is not None - dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] - assert dns_server_service.name is "DNSServer" - assert dns_server_service.port is Port.DNS - assert dns_server_service.protocol is IPProtocol.TCP - - def test_create_dns_client(dns_client): assert dns_client is not None - dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client_service.name is "DNSClient" assert dns_client_service.port is Port.DNS assert dns_client_service.protocol is IPProtocol.TCP def test_dns_client_add_domain_to_cache_when_not_running(dns_client): - dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") assert dns_client.operating_state is NodeOperatingState.OFF assert dns_client_service.operating_state is ServiceOperatingState.STOPPED @@ -69,7 +46,7 @@ def test_dns_client_add_domain_to_cache_when_not_running(dns_client): def test_dns_client_check_domain_exists_when_not_running(dns_client): dns_client.operating_state = NodeOperatingState.ON - dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") dns_client_service.start() assert dns_client.operating_state is NodeOperatingState.ON @@ -93,22 +70,10 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client): assert dns_client_service.check_domain_exists("test.com") is False -def test_dns_server_domain_name_registration(dns_server): - """Test to check if the domain name registration works.""" - dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] - - # register the web server in the domain controller - dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - - # return none for an unknown domain - assert dns_server_service.dns_lookup("fake-domain.com") is None - assert dns_server_service.dns_lookup("real-domain.com") is not None - - def test_dns_client_check_domain_in_cache(dns_client): """Test to make sure that the check_domain_in_cache returns the correct values.""" dns_client.operating_state = NodeOperatingState.ON - dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") dns_client_service.start() # add a domain to the dns client cache @@ -118,29 +83,9 @@ def test_dns_client_check_domain_in_cache(dns_client): assert dns_client_service.check_domain_exists("real-domain.com") is True -def test_dns_server_receive(dns_server): - """Test to make sure that the DNS Server correctly responds to a DNS Client request.""" - dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"] - - # register the web server in the domain controller - dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) - - assert ( - dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com"))) - is False - ) - - assert ( - dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com"))) - is True - ) - - dns_server_service.show() - - def test_dns_client_receive(dns_client): """Test to make sure the DNS Client knows how to deal with request responses.""" - dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") dns_client_service.receive( payload=DNSPacket( @@ -151,3 +96,9 @@ def test_dns_client_receive(dns_client): # domain name should be saved to cache assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12") + + +def test_dns_client_receive_non_dns_payload(dns_client): + dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient") + + assert dns_client_service.receive(payload=None) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py new file mode 100644 index 00000000..eb042c92 --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_dns_server.py @@ -0,0 +1,64 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.server import Server +from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.dns.dns_server import DNSServer + + +@pytest.fixture(scope="function") +def dns_server() -> Node: + node = Server( + hostname="dns_server", + ip_address="192.168.1.10", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + node.software_manager.install(software_class=DNSServer) + return node + + +def test_create_dns_server(dns_server): + assert dns_server is not None + dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") + assert dns_server_service.name is "DNSServer" + assert dns_server_service.port is Port.DNS + assert dns_server_service.protocol is IPProtocol.TCP + + +def test_dns_server_domain_name_registration(dns_server): + """Test to check if the domain name registration works.""" + dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + # return none for an unknown domain + assert dns_server_service.dns_lookup("fake-domain.com") is None + assert dns_server_service.dns_lookup("real-domain.com") is not None + + +def test_dns_server_receive(dns_server): + """Test to make sure that the DNS Server correctly responds to a DNS Client request.""" + dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer") + + # register the web server in the domain controller + dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12")) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com"))) + is False + ) + + assert ( + dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com"))) + is True + ) + + dns_server_service.show() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py new file mode 100644 index 00000000..134f82bd --- /dev/null +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -0,0 +1,122 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.hardware.base import Node +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.computer import Computer +from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.service import ServiceOperatingState + + +@pytest.fixture(scope="function") +def ftp_client() -> Node: + node = Computer( + hostname="ftp_client", + ip_address="192.168.1.11", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + operating_state=NodeOperatingState.ON, + ) + return node + + +def test_create_ftp_client(ftp_client): + assert ftp_client is not None + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + assert ftp_client_service.name is "FTPClient" + assert ftp_client_service.port is Port.FTP + assert ftp_client_service.protocol is IPProtocol.TCP + + +def test_ftp_client_store_file(ftp_client): + """Test to make sure the FTP Client knows how to deal with request responses.""" + assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None + + response: FTPPacket = FTPPacket( + ftp_command=FTPCommand.STOR, + ftp_command_args={ + "dest_folder_name": "downloads", + "dest_file_name": "file.txt", + "file_size": 24, + }, + packet_payload_size=24, + status_code=FTPStatusCode.OK, + ) + + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service.receive(response) + + assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") + + +def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): + """Method _process_ftp_command should return false if service is not running.""" + payload: FTPPacket = FTPPacket( + ftp_command=FTPCommand.PORT, + ftp_command_args=Port.FTP, + status_code=FTPStatusCode.OK, + ) + + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client_service.stop() + assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED + assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR + + +def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client): + """Method send_file should return false if no file to send.""" + assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None + + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING + assert ( + ftp_client_service.send_file( + dest_ip_address=IPv4Address("192.168.1.1"), + src_folder_name="root", + src_file_name="test.txt", + dest_folder_name="root", + dest_file_name="text.txt", + ) + is False + ) + + +def test_offline_ftp_client_receives_request(ftp_client): + """Receive should return false if the node the ftp client is installed on is offline.""" + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + ftp_client.power_off() + + for i in range(ftp_client.shut_down_duration + 1): + ftp_client.apply_timestep(timestep=i) + + assert ftp_client.operating_state is NodeOperatingState.OFF + assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED + + payload: FTPPacket = FTPPacket( + ftp_command=FTPCommand.PORT, + ftp_command_args=Port.FTP, + status_code=FTPStatusCode.OK, + ) + + assert ftp_client_service.receive(payload=payload) is False + + +def test_receive_should_fail_if_payload_is_not_ftp(ftp_client): + """Receive should return false if the node the ftp client is installed on is not an FTPPacket.""" + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + assert ftp_client_service.receive(payload=None) is False + + +def test_receive_should_ignore_payload_with_none_status_code(ftp_client): + """Receive should ignore payload with no set status code to prevent infinite send/receive loops.""" + payload: FTPPacket = FTPPacket( + ftp_command=FTPCommand.PORT, + ftp_command_args=Port.FTP, + status_code=None, + ) + ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient") + assert ftp_client_service.receive(payload=payload) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py similarity index 62% rename from tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp.py rename to tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py index 9957b6f6..2b26c932 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_server.py @@ -1,16 +1,13 @@ -from ipaddress import IPv4Address - import pytest from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.services.ftp.ftp_client import FTPClient from primaite.simulator.system.services.ftp.ftp_server import FTPServer +from primaite.simulator.system.services.service import ServiceOperatingState @pytest.fixture(scope="function") @@ -26,34 +23,14 @@ def ftp_server() -> Node: return node -@pytest.fixture(scope="function") -def ftp_client() -> Node: - node = Computer( - hostname="ftp_client", - ip_address="192.168.1.11", - subnet_mask="255.255.255.0", - default_gateway="192.168.1.1", - operating_state=NodeOperatingState.ON, - ) - return node - - def test_create_ftp_server(ftp_server): assert ftp_server is not None - ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"] + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") assert ftp_server_service.name is "FTPServer" assert ftp_server_service.port is Port.FTP assert ftp_server_service.protocol is IPProtocol.TCP -def test_create_ftp_client(ftp_client): - assert ftp_client is not None - ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"] - assert ftp_client_service.name is "FTPClient" - assert ftp_client_service.port is Port.FTP - assert ftp_client_service.protocol is IPProtocol.TCP - - def test_ftp_server_store_file(ftp_server): """Test to make sure the FTP Server knows how to deal with request responses.""" assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") is None @@ -68,16 +45,34 @@ def test_ftp_server_store_file(ftp_server): packet_payload_size=24, ) - ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"] + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") ftp_server_service.receive(response) assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") -def test_ftp_client_store_file(ftp_client): - """Test to make sure the FTP Client knows how to deal with request responses.""" - assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None +def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server): + """Should fail if the port command receives an invalid port.""" + payload: FTPPacket = FTPPacket( + ftp_command=FTPCommand.PORT, + ftp_command_args=None, + packet_payload_size=24, + ) + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR + + +def test_ftp_server_receives_non_ftp_packet(ftp_server): + """Receive should return false if the service receives a non ftp packet.""" + response: FTPPacket = None + + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + assert ftp_server_service.receive(response) is False + + +def test_offline_ftp_server_receives_request(ftp_server): + """Receive should return false if the service is stopped.""" response: FTPPacket = FTPPacket( ftp_command=FTPCommand.STOR, ftp_command_args={ @@ -86,10 +81,9 @@ def test_ftp_client_store_file(ftp_client): "file_size": 24, }, packet_payload_size=24, - status_code=FTPStatusCode.OK, ) - ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"] - ftp_client_service.receive(response) - - assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") + ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer") + ftp_server_service.stop() + assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED + assert ftp_server_service.receive(response) is False diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py index e6f0b9d9..bbccda27 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_web_server.py @@ -18,13 +18,13 @@ def web_server() -> Server: hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" ) node.software_manager.install(software_class=WebServer) - node.software_manager.software["WebServer"].start() + node.software_manager.software.get("WebServer").start() return node def test_create_web_server(web_server): assert web_server is not None - web_server_service: WebServer = web_server.software_manager.software["WebServer"] + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service.name is "WebServer" assert web_server_service.port is Port.HTTP assert web_server_service.protocol is IPProtocol.TCP @@ -33,7 +33,7 @@ def test_create_web_server(web_server): def test_handling_get_request_not_found_path(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path") - web_server_service: WebServer = web_server.software_manager.software["WebServer"] + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) assert response.status_code == HttpStatusCode.NOT_FOUND @@ -42,7 +42,7 @@ def test_handling_get_request_not_found_path(web_server): def test_handling_get_request_home_page(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/") - web_server_service: WebServer = web_server.software_manager.software["WebServer"] + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) assert response.status_code == HttpStatusCode.OK @@ -51,7 +51,7 @@ def test_handling_get_request_home_page(web_server): def test_process_http_request_get(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/") - web_server_service: WebServer = web_server.software_manager.software["WebServer"] + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service._process_http_request(payload=payload) is True @@ -59,6 +59,6 @@ def test_process_http_request_get(web_server): def test_process_http_request_method_not_allowed(web_server): payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/") - web_server_service: WebServer = web_server.software_manager.software["WebServer"] + web_server_service: WebServer = web_server.software_manager.software.get("WebServer") assert web_server_service._process_http_request(payload=payload) is False