Merge 'origin/dev' into bugfix/episode-length-and-rewards

This commit is contained in:
Marek Wolan
2023-12-03 14:49:34 +00:00
47 changed files with 1199 additions and 296 deletions

View File

@@ -6,6 +6,9 @@ trigger:
- bugfix/* - bugfix/*
- release/* - release/*
pr:
autoCancel: true
drafts: false
parameters: parameters:
# https://stackoverflow.com/a/70046417 # https://stackoverflow.com/a/70046417
- name: matrix - name: matrix
@@ -85,6 +88,33 @@ stages:
primaite setup primaite setup
displayName: 'Perform PrimAITE Setup' displayName: 'Perform PrimAITE Setup'
- task: UseDotNet@2
displayName: 'Install dotnet dependencies'
inputs:
packageType: 'sdk'
version: '2.1.x'
- script: | - script: |
pytest -n auto coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml
displayName: 'Run tests' coverage xml -o coverage.xml -i
coverage html -d htmlcov -i
displayName: 'Run tests and code coverage'
- task: PublishTestResults@2
condition: succeededOrFailed()
inputs:
testRunner: JUnit
testResultsFiles: 'junit/**.xml'
testRunTitle: 'Publish test results'
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
# publish the html report - so we can debug the coverage if needed
condition: ${{ item.every_time }} # should only be run once
artifact: coverage_report
- task: PublishCodeCoverageResults@2
# publish the code coverage so it can be viewed in the run coverage page
condition: ${{ item.every_time }} # should only be run once
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'

1
.gitignore vendored
View File

@@ -37,6 +37,7 @@ pip-log.txt
pip-delete-this-directory.txt pip-delete-this-directory.txt
# Unit test / coverage reports # Unit test / coverage reports
junit/
htmlcov/ htmlcov/
.tox/ .tox/
.nox/ .nox/

View File

@@ -54,7 +54,7 @@ Example
) )
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot) client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
data_manipulation_bot.run() data_manipulation_bot.run()

View File

@@ -28,7 +28,7 @@ See :ref:`Node Start up and Shut down`
node.software_manager.install(WebServer) node.software_manager.install(WebServer)
web_server: WebServer = node.software_manager.software["WebServer"] web_server: WebServer = node.software_manager.software.get("WebServer")
assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install
node.power_off() node.power_off()

View File

@@ -424,7 +424,7 @@ class NetworkACLAddRuleAction(AbstractAction):
elif permission == 2: elif permission == 2:
permission_str = "DENY" permission_str = "DENY"
else: else:
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.") _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0: if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS

View File

@@ -264,7 +264,7 @@ class FolderObservation(AbstractObservation):
while len(self.files) > num_files_per_folder: while len(self.files) > num_files_per_folder:
truncated_file = self.files.pop() truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}" msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warn(msg) _LOGGER.warning(msg)
self.default_observation = { self.default_observation = {
"health_status": 0, "health_status": 0,
@@ -438,7 +438,7 @@ class NodeObservation(AbstractObservation):
while len(self.services) > num_services_per_node: while len(self.services) > num_services_per_node:
truncated_service = self.services.pop() truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warn(msg) _LOGGER.warning(msg)
# truncate service list # truncate service list
self.folders: List[FolderObservation] = folders self.folders: List[FolderObservation] = folders
@@ -448,7 +448,7 @@ class NodeObservation(AbstractObservation):
while len(self.folders) > num_folders_per_node: while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop() truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warn(msg) _LOGGER.warning(msg)
self.nics: List[NicObservation] = nics self.nics: List[NicObservation] = nics
while len(self.nics) < num_nics_per_node: while len(self.nics) < num_nics_per_node:
@@ -456,7 +456,7 @@ class NodeObservation(AbstractObservation):
while len(self.nics) > num_nics_per_node: while len(self.nics) > num_nics_per_node:
truncated_nic = self.nics.pop() truncated_nic = self.nics.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warn(msg) _LOGGER.warning(msg)
self.logon_status: bool = logon_status self.logon_status: bool = logon_status

View File

@@ -210,7 +210,7 @@ class WebServer404Penalty(AbstractReward):
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not " f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config." "found in reward config."
) )
_LOGGER.warn(msg) _LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably! return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref] node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref] service_uuid = game.ref_map_services[service_ref]
@@ -219,7 +219,7 @@ class WebServer404Penalty(AbstractReward):
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not" f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator." " found in the simulator."
) )
_LOGGER.warn(msg) _LOGGER.warning(msg)
return DummyReward() # TODO: consider erroring here as well return DummyReward() # TODO: consider erroring here as well
return cls(node_uuid=node_uuid, service_uuid=service_uuid) return cls(node_uuid=node_uuid, service_uuid=service_uuid)

View File

@@ -252,6 +252,6 @@ class SimComponent(BaseModel):
def parent(self, new_parent: Union["SimComponent", None]) -> None: def parent(self, new_parent: Union["SimComponent", None]) -> None:
if self._parent and new_parent: if self._parent and new_parent:
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}" msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
_LOGGER.warn(msg) _LOGGER.warning(msg)
raise RuntimeWarning(msg) raise RuntimeWarning(msg)
self._parent = new_parent self._parent = new_parent

View File

@@ -72,7 +72,7 @@ class Account(SimComponent):
"num_group_changes": self.num_group_changes, "num_group_changes": self.num_group_changes,
"username": self.username, "username": self.username,
"password": self.password, "password": self.password,
"account_type": self.account_type.name, "account_type": self.account_type.value,
"enabled": self.enabled, "enabled": self.enabled,
} }
) )

View File

@@ -53,7 +53,10 @@ class FileSystem(SimComponent):
original_folder_uuids = self._original_state["original_folder_uuids"] original_folder_uuids = self._original_state["original_folder_uuids"]
for uuid in original_folder_uuids: for uuid in original_folder_uuids:
if uuid in self.deleted_folders: if uuid in self.deleted_folders:
self.folders[uuid] = self.deleted_folders.pop(uuid) folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
self._folders_by_name[folder.name] = folder
# Clear any other deleted folders that aren't original (have been created by agent) # Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear() self.deleted_folders.clear()
@@ -62,7 +65,9 @@ class FileSystem(SimComponent):
current_folder_uuids = list(self.folders.keys()) current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids: for uuid in current_folder_uuids:
if uuid not in original_folder_uuids: if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid) self.folders.pop(uuid)
self._folders_by_name.pop(folder.name)
# Now reset all remaining folders # Now reset all remaining folders
for folder in self.folders.values(): for folder in self.folders.values():

View File

@@ -75,7 +75,10 @@ class Folder(FileSystemItemABC):
original_file_uuids = self._original_state["original_file_uuids"] original_file_uuids = self._original_state["original_file_uuids"]
for uuid in original_file_uuids: for uuid in original_file_uuids:
if uuid in self.deleted_files: if uuid in self.deleted_files:
self.files[uuid] = self.deleted_files.pop(uuid) file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
self._files_by_name[file.name] = file
# Clear any other deleted files that aren't original (have been created by agent) # Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear() self.deleted_files.clear()
@@ -84,7 +87,9 @@ class Folder(FileSystemItemABC):
current_file_uuids = list(self.files.keys()) current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids: for uuid in current_file_uuids:
if uuid not in original_file_uuids: if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid) self.files.pop(uuid)
self._files_by_name.pop(file.name)
# Now reset all remaining files # Now reset all remaining files
for file in self.files.values(): for file in self.files.values():

View File

@@ -51,14 +51,22 @@ def client_server_routed() -> Network:
# Client 1 # Client 1
client_1 = Computer( client_1 = Computer(
hostname="client_1", ip_address="192.168.2.2", subnet_mask="255.255.255.0", default_gateway="192.168.2.1" hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
operating_state=NodeOperatingState.ON,
) )
client_1.power_on() client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
# Server 1 # Server 1
server_1 = Server( server_1 = Server(
hostname="server_1", ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
) )
server_1.power_on() server_1.power_on()
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1]) network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
@@ -139,7 +147,7 @@ def arcd_uc2_network() -> Network:
client_1.power_on() client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1]) network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot) client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot.configure( db_manipulation_bot.configure(
server_ip_address=IPv4Address("192.168.1.14"), server_ip_address=IPv4Address("192.168.1.14"),
payload="DELETE", payload="DELETE",
@@ -157,7 +165,7 @@ def arcd_uc2_network() -> Network:
operating_state=NodeOperatingState.ON, operating_state=NodeOperatingState.ON,
) )
client_2.power_on() client_2.power_on()
web_browser = client_2.software_manager.software["WebBrowser"] web_browser = client_2.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/" web_browser.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2]) network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
@@ -241,7 +249,7 @@ def arcd_uc2_network() -> Network:
# noqa # noqa
] ]
database_server.software_manager.install(DatabaseService) database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
database_service.start() database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None) # noqa database_service._process_sql(ddl, None) # noqa
@@ -260,7 +268,7 @@ def arcd_uc2_network() -> Network:
web_server.power_on() web_server.power_on()
web_server.software_manager.install(DatabaseClient) web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.1.14")) database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2]) network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
database_client.run() database_client.run()
@@ -269,7 +277,7 @@ def arcd_uc2_network() -> Network:
web_server.software_manager.install(WebServer) web_server.software_manager.install(WebServer)
# register the web_server to a domain # register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa
dns_server_service.dns_register("arcd.com", web_server.ip_address) dns_server_service.dns_register("arcd.com", web_server.ip_address)
# Backup Server # Backup Server

View File

@@ -5,6 +5,8 @@ def convert_bytes_to_megabits(B: Union[int, float]) -> float: # noqa - Keep it
""" """
Convert Bytes (file size) to Megabits (data transfer). Convert Bytes (file size) to Megabits (data transfer).
Technically Mebibits - but for simplicity sake, we'll call it megabit
:param B: The file size in Bytes. :param B: The file size in Bytes.
:return: File bits to transfer in Megabits. :return: File bits to transfer in Megabits.
""" """

View File

@@ -73,7 +73,8 @@ class DatabaseClient(Application):
if not self.connected: if not self.connected:
return self._connect(self.server_ip_address, self.server_password) return self._connect(self.server_ip_address, self.server_password)
return False # already connected
return True
def _connect( def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
@@ -107,7 +108,7 @@ class DatabaseClient(Application):
def disconnect(self): def disconnect(self):
"""Disconnect from the Database Service.""" """Disconnect from the Database Service."""
if self.connected and self.operating_state.RUNNING: if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
software_manager: SoftwareManager = self.software_manager software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager( software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
@@ -186,6 +187,9 @@ class DatabaseClient(Application):
:param session_id: The session id the payload relates to. :param session_id: The session id the payload relates to.
:return: True. :return: True.
""" """
if not self._can_perform_action():
return False
if isinstance(payload, dict) and payload.get("type"): if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response": if payload["type"] == "connect_response":
self.connected = payload["response"] == True self.connected = payload["response"] == True

View File

@@ -99,7 +99,7 @@ class WebBrowser(Application):
return False return False
# get the IP address of the domain name via DNS # get the IP address of the domain name via DNS
dns_client: DNSClient = self.software_manager.software["DNSClient"] dns_client: DNSClient = self.software_manager.software.get("DNSClient")
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname) domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
# if domain does not exist, the request fails # if domain does not exist, the request fails

View File

@@ -80,7 +80,7 @@ class DatabaseService(Service):
return False return False
software_manager: SoftwareManager = self.software_manager software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"] ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# send backup copy of database file to FTP server # send backup copy of database file to FTP server
response = ftp_client_service.send_file( response = ftp_client_service.send_file(
@@ -104,7 +104,7 @@ class DatabaseService(Service):
return False return False
software_manager: SoftwareManager = self.software_manager software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"] ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# retrieve backup file from backup server # retrieve backup file from backup server
response = ftp_client_service.request_file( response = ftp_client_service.request_file(

View File

@@ -8,7 +8,6 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@@ -53,8 +52,7 @@ class FTPClient(FTPServiceABC):
:type: session_id: Optional[str] :type: session_id: Optional[str]
""" """
# if client service is down, return error # if client service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING: if not self._can_perform_action():
self.sys_log.error("FTP Client is not running")
payload.status_code = FTPStatusCode.ERROR payload.status_code = FTPStatusCode.ERROR
return payload return payload
@@ -81,8 +79,7 @@ class FTPClient(FTPServiceABC):
:type: is_reattempt: Optional[bool] :type: is_reattempt: Optional[bool]
""" """
# make sure the service is running before attempting # make sure the service is running before attempting
if self.operating_state != ServiceOperatingState.RUNNING: if not self._can_perform_action():
self.sys_log.error(f"FTPClient not running for {self.sys_log.hostname}")
return False return False
# normally FTP will choose a random port for the transfer, but using the FTP command port will do for now # normally FTP will choose a random port for the transfer, but using the FTP command port will do for now
@@ -282,8 +279,11 @@ class FTPClient(FTPServiceABC):
This helps prevent an FTP request loop - FTP client and servers can exist on This helps prevent an FTP request loop - FTP client and servers can exist on
the same node. the same node.
""" """
if not self._can_perform_action():
return False
if payload.status_code is None: if payload.status_code is None:
self.sys_log.error(f"FTP Server could not be found - Error Code: {payload.status_code.value}") self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
return False return False
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}") self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")

View File

@@ -6,7 +6,6 @@ from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPS
from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@@ -58,8 +57,7 @@ class FTPServer(FTPServiceABC):
payload.status_code = FTPStatusCode.ERROR payload.status_code = FTPStatusCode.ERROR
# if server service is down, return error # if server service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING: if not self._can_perform_action():
self.sys_log.error("FTP Server not running")
return payload return payload
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}") self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
@@ -95,6 +93,9 @@ class FTPServer(FTPServiceABC):
self.sys_log.error(f"{payload} is not an FTP packet") self.sys_log.error(f"{payload} is not an FTP packet")
return False return False
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
""" """
Ignore ftp payload if status code is defined. Ignore ftp payload if status code is defined.
@@ -102,9 +103,6 @@ class FTPServer(FTPServiceABC):
prevents an FTP request loop - FTP client and servers can exist on prevents an FTP request loop - FTP client and servers can exist on
the same node. the same node.
""" """
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
if payload.status_code is not None: if payload.status_code is not None:
return False return False

View File

@@ -109,8 +109,8 @@ class Service(IOSoftware):
""" """
state = super().describe_state() state = super().describe_state()
state["operating_state"] = self.operating_state.value state["operating_state"] = self.operating_state.value
state["health_state_actual"] = self.health_state_actual state["health_state_actual"] = self.health_state_actual.value
state["health_state_visible"] = self.health_state_visible state["health_state_visible"] = self.health_state_visible.value
return state return state
def stop(self) -> None: def stop(self) -> None:

View File

@@ -119,7 +119,7 @@ class WebServer(Service):
if path.startswith("users"): if path.startswith("users"):
# get data from DatabaseServer # get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"] db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
# get all users # get all users
if db_client.query("SELECT"): if db_client.query("SELECT"):
# query succeeded # query succeeded

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Union from typing import Any, Dict, Tuple, Union
import pytest import pytest
import yaml import yaml
@@ -12,6 +12,11 @@ from primaite.session.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite # from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession
from primaite.simulator.network.container import Network from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.networks import arcd_uc2_network from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
@@ -29,7 +34,7 @@ from primaite import PRIMAITE_PATHS
# PrimAITE v3 stuff # PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.base import Link, Node
class TestService(Service): class TestService(Service):
@@ -122,3 +127,110 @@ def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path()) monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0] config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path) return TempPrimaiteSession.from_config(config_path=config_path)
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
# Create Computer
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
# Create Server
server = Server(
hostname="server", ip_address="192.168.0.2", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
# Connect Computer and Server
computer_nic = computer.nics[next(iter(computer.nics))]
server_nic = server.nics[next(iter(server.nics))]
link = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
# Should be linked
assert link.is_up
return computer, server
@pytest.fixture(scope="function")
def example_network() -> Network:
"""
Create the network used for testing.
Should only contain the nodes and links.
This would act as the base network and services and applications are installed in the relevant test file,
-------------- --------------
| client_1 |----- ----| server_1 |
-------------- | -------------- -------------- -------------- | --------------
------| switch_1 |------| router_1 |------| switch_2 |------
-------------- | -------------- -------------- -------------- | --------------
| client_2 |---- ----| server_2 |
-------------- --------------
"""
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON)
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON)
network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
# Database Server
server_2 = Server(
hostname="server_2",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=server_2.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
return network

View File

@@ -2,6 +2,7 @@
import tempfile import tempfile
from pathlib import Path from pathlib import Path
import pytest
import yaml import yaml
from stable_baselines3 import PPO from stable_baselines3 import PPO
@@ -10,6 +11,7 @@ from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv from primaite.session.environment import PrimaiteGymEnv
# @pytest.mark.skip(reason="no way of currently testing this")
def test_sb3_compatibility(): def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent.""" """Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f: with open(example_config_path(), "r") as f:

View File

@@ -11,6 +11,7 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
# @pytest.mark.skip(reason="no way of currently testing this")
class TestPrimaiteSession: class TestPrimaiteSession:
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session): def test_creating_session(self, temp_primaite_session):

View File

@@ -8,13 +8,13 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
def test_data_manipulation(uc2_network): def test_data_manipulation(uc2_network):
"""Tests the UC2 data manipulation scenario end-to-end. Is a work in progress.""" """Tests the UC2 data manipulation scenario end-to-end. Is a work in progress."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1") client_1: Computer = uc2_network.get_node_by_hostname("client_1")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"] db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
database_server: Server = uc2_network.get_node_by_hostname("database_server") database_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = database_server.software_manager.software["DatabaseService"] db_service: DatabaseService = database_server.software_manager.software.get("DatabaseService")
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_service.backup_database() db_service.backup_database()

View File

@@ -16,3 +16,9 @@ def test_link_up():
assert nic_a.enabled assert nic_a.enabled
assert nic_b.enabled assert nic_b.enabled
assert link.is_up assert link.is_up
def test_ping_between_computer_and_server(client_server):
computer, server = client_server
assert computer.ping(target_ip_address=server.nics[next(iter(server.nics))].ip_address)

View File

@@ -2,6 +2,28 @@ import pytest
from primaite.simulator.network.container import Network from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC, Node from primaite.simulator.network.hardware.base import NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.networks import client_server_routed
def test_network(example_network):
network: Network = example_network
client_1: Computer = network.get_node_by_hostname("client_1")
client_2: Computer = network.get_node_by_hostname("client_2")
server_1: Server = network.get_node_by_hostname("server_1")
server_2: Server = network.get_node_by_hostname("server_2")
assert client_1.ping(client_2.ethernet_port[1].ip_address)
assert client_2.ping(client_1.ethernet_port[1].ip_address)
assert server_1.ping(server_2.ethernet_port[1].ip_address)
assert server_2.ping(server_1.ethernet_port[1].ip_address)
assert client_1.ping(server_1.ethernet_port[1].ip_address)
assert client_2.ping(server_1.ethernet_port[1].ip_address)
assert client_1.ping(server_2.ethernet_port[1].ip_address)
assert client_2.ping(server_2.ethernet_port[1].ip_address)
def test_adding_removing_nodes(): def test_adding_removing_nodes():

View File

@@ -18,7 +18,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
) )
computer.software_manager.install(application_class) computer.software_manager.install(application_class)
app = computer.software_manager.software["TestApplication"] app = computer.software_manager.software.get("TestApplication")
app.run() app.run()
return app, computer return app, computer
@@ -35,7 +35,7 @@ def test_service_on_offline_node(application_class):
) )
computer.software_manager.install(application_class) computer.software_manager.install(application_class)
app: Application = computer.software_manager.software["TestApplication"] app: Application = computer.software_manager.software.get("TestApplication")
computer.power_off() computer.power_off()

View File

@@ -10,10 +10,10 @@ from primaite.simulator.system.services.service import ServiceOperatingState
def test_database_client_server_connection(uc2_network): def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
assert len(db_service.connections) == 1 assert len(db_service.connections) == 1
@@ -23,10 +23,10 @@ def test_database_client_server_connection(uc2_network):
def test_database_client_server_correct_password(uc2_network): def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client.disconnect() db_client.disconnect()
@@ -40,10 +40,10 @@ def test_database_client_server_correct_password(uc2_network):
def test_database_client_server_incorrect_password(uc2_network): def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client.disconnect() db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321") db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
@@ -56,7 +56,7 @@ def test_database_client_server_incorrect_password(uc2_network):
def test_database_client_query(uc2_network): def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date.""" """Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected assert db_client.connected
@@ -66,13 +66,13 @@ def test_database_client_query(uc2_network):
def test_create_database_backup(uc2_network): def test_create_database_backup(uc2_network):
"""Run the backup_database method and check if the FTP server has the relevant file.""" """Run the backup_database method and check if the FTP server has the relevant file."""
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
# back up should be created # back up should be created
assert db_service.backup_database() is True assert db_service.backup_database() is True
backup_server: Server = uc2_network.get_node_by_hostname("backup_server") backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"] ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
# backup file should exist in the backup server # backup file should exist in the backup server
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
@@ -81,7 +81,7 @@ def test_create_database_backup(uc2_network):
def test_restore_backup(uc2_network): def test_restore_backup(uc2_network):
"""Run the restore_backup method and check if the backup is properly restored.""" """Run the restore_backup method and check if the backup is properly restored."""
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
# create a back up # create a back up
assert db_service.backup_database() is True assert db_service.backup_database() is True
@@ -100,13 +100,13 @@ def test_restore_backup(uc2_network):
def test_database_client_cannot_query_offline_database_server(uc2_network): def test_database_client_cannot_query_offline_database_server(uc2_network):
"""Tests DB query across the network returns HTTP status 404 when db server is offline.""" """Tests DB query across the network returns HTTP status 404 when db server is offline."""
db_server: Server = uc2_network.get_node_by_hostname("database_server") db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
assert db_server.operating_state is NodeOperatingState.ON assert db_server.operating_state is NodeOperatingState.ON
assert db_service.operating_state is ServiceOperatingState.RUNNING assert db_service.operating_state is ServiceOperatingState.RUNNING
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected assert db_client.connected
assert db_client.query("SELECT") is True assert db_client.query("SELECT") is True

View File

@@ -1,3 +1,8 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.server import Server
@@ -6,12 +11,31 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.service import ServiceOperatingState
def test_dns_client_server(uc2_network): @pytest.fixture(scope="function")
client_1: Computer = uc2_network.get_node_by_hostname("client_1") def dns_client_and_dns_server(client_server) -> Tuple[DNSClient, Computer, DNSServer, Server]:
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller") computer, server = client_server
dns_client: DNSClient = client_1.software_manager.software["DNSClient"] # Install DNS Client on computer
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"] computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
dns_client.start()
# set server as DNS Server
dns_client.dns_server = IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
# Install DNS Server on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
dns_server.start()
# register arcd.com as a domain
dns_server.dns_register(
domain_name="arcd.com", domain_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
)
return dns_client, computer, dns_server, server
def test_dns_client_server(dns_client_and_dns_server):
dns_client, computer, dns_server, server = dns_client_and_dns_server
assert dns_client.operating_state == ServiceOperatingState.RUNNING assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING assert dns_server.operating_state == ServiceOperatingState.RUNNING
@@ -29,12 +53,8 @@ def test_dns_client_server(uc2_network):
assert len(dns_client.dns_cache) == 1 assert len(dns_client.dns_cache) == 1
def test_dns_client_requests_offline_dns_server(uc2_network): def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server):
client_1: Computer = uc2_network.get_node_by_hostname("client_1") dns_client, computer, dns_server, server = dns_client_and_dns_server
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
assert dns_client.operating_state == ServiceOperatingState.RUNNING assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING assert dns_server.operating_state == ServiceOperatingState.RUNNING
@@ -48,12 +68,12 @@ def test_dns_client_requests_offline_dns_server(uc2_network):
assert len(dns_client.dns_cache) == 1 assert len(dns_client.dns_cache) == 1
dns_client.dns_cache = {} dns_client.dns_cache = {}
domain_controller.power_off() server.power_off()
for i in range(domain_controller.shut_down_duration + 1): for i in range(server.shut_down_duration + 1):
uc2_network.apply_timestep(timestep=i) server.apply_timestep(timestep=i)
assert domain_controller.operating_state == NodeOperatingState.OFF assert server.operating_state == NodeOperatingState.OFF
assert dns_server.operating_state == ServiceOperatingState.STOPPED assert dns_server.operating_state == ServiceOperatingState.STOPPED
# this time it should not cache because dns server is not online # this time it should not cache because dns server is not online

View File

@@ -1,4 +1,7 @@
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.server import Server
@@ -7,18 +10,31 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.service import ServiceOperatingState
def test_ftp_client_store_file_in_server(uc2_network): @pytest.fixture(scope="function")
def ftp_client_and_ftp_server(client_server) -> Tuple[FTPClient, Computer, FTPServer, Server]:
computer, server = client_server
# Install FTP Client service on computer
computer.software_manager.install(FTPClient)
ftp_client: FTPClient = computer.software_manager.software.get("FTPClient")
ftp_client.start()
# Install FTP Server service on server
server.software_manager.install(FTPServer)
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
ftp_server.start()
return ftp_client, computer, ftp_server, server
def test_ftp_client_store_file_in_server(ftp_client_and_ftp_server):
""" """
Test checks to see if the client is able to store files in the backup server. Test checks to see if the client is able to store files in the backup server.
""" """
client_1: Computer = uc2_network.get_node_by_hostname("client_1") ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
assert ftp_client.operating_state == ServiceOperatingState.RUNNING assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING assert ftp_server.operating_state == ServiceOperatingState.RUNNING
# create file on ftp client # create file on ftp client
ftp_client.file_system.create_file(file_name="test_file.txt") ftp_client.file_system.create_file(file_name="test_file.txt")
@@ -28,61 +44,53 @@ def test_ftp_client_store_file_in_server(uc2_network):
src_file_name="test_file.txt", src_file_name="test_file.txt",
dest_folder_name="client_1_backup", dest_folder_name="client_1_backup",
dest_file_name="test_file.txt", dest_file_name="test_file.txt",
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
) )
assert ftp_server_service.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt") assert ftp_server.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt")
def test_ftp_client_retrieve_file_from_server(uc2_network): def test_ftp_client_retrieve_file_from_server(ftp_client_and_ftp_server):
""" """
Test checks to see if the client is able to retrieve files from the backup server. Test checks to see if the client is able to retrieve files from the backup server.
""" """
client_1: Computer = uc2_network.get_node_by_hostname("client_1") ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
assert ftp_client.operating_state == ServiceOperatingState.RUNNING assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING assert ftp_server.operating_state == ServiceOperatingState.RUNNING
# create file on ftp server # create file on ftp server
ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share") ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
assert ftp_client.request_file( assert ftp_client.request_file(
src_folder_name="file_share", src_folder_name="file_share",
src_file_name="test_file.txt", src_file_name="test_file.txt",
dest_folder_name="downloads", dest_folder_name="downloads",
dest_file_name="test_file.txt", dest_file_name="test_file.txt",
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
) )
# client should have retrieved the file # client should have retrieved the file
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt") assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt")
def test_ftp_client_tries_to_connect_to_offline_server(uc2_network): def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server):
"""Test checks to make sure that the client can't do anything when the server is offline.""" """Test checks to make sure that the client can't do anything when the server is offline."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1") ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
assert ftp_client.operating_state == ServiceOperatingState.RUNNING assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING assert ftp_server.operating_state == ServiceOperatingState.RUNNING
# create file on ftp server # create file on ftp server
ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share") ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
backup_server.power_off() server.power_off()
for i in range(backup_server.shut_down_duration + 1): for i in range(server.shut_down_duration + 1):
uc2_network.apply_timestep(timestep=i) server.apply_timestep(timestep=i)
assert ftp_client.operating_state == ServiceOperatingState.RUNNING assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server_service.operating_state == ServiceOperatingState.STOPPED assert ftp_server.operating_state == ServiceOperatingState.STOPPED
assert ( assert (
ftp_client.request_file( ftp_client.request_file(
@@ -90,7 +98,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(uc2_network):
src_file_name="test_file.txt", src_file_name="test_file.txt",
dest_folder_name="downloads", dest_folder_name="downloads",
dest_file_name="test_file.txt", dest_file_name="test_file.txt",
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address, dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
) )
is False is False
) )

View File

@@ -17,7 +17,7 @@ def populated_node(
) )
server.software_manager.install(service_class) server.software_manager.install(service_class)
service = server.software_manager.software["TestService"] service = server.software_manager.software.get("TestService")
service.start() service.start()
return server, service return server, service
@@ -34,7 +34,7 @@ def test_service_on_offline_node(service_class):
) )
computer.software_manager.install(service_class) computer.software_manager.install(service_class)
service: Service = computer.software_manager.software["TestService"] service: Service = computer.software_manager.software.get("TestService")
computer.power_off() computer.power_off()

View File

@@ -1,104 +1,118 @@
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.http import HttpStatusCode from primaite.simulator.network.protocols.http import HttpStatusCode
from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
def test_web_page_home_page(uc2_network): @pytest.fixture(scope="function")
"""Test to see if the browser is able to open the main page of the web server.""" def web_client_and_web_server(client_server) -> Tuple[WebBrowser, Computer, WebServer, Server]:
client_1: Computer = uc2_network.get_node_by_hostname("client_1") computer, server = client_server
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_client.target_url = "http://arcd.com/"
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage() is True # Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.run()
# latest reponse should have status code 200 # Install DNS Client service on computer
assert web_client.latest_response is not None computer.software_manager.install(DNSClient)
assert web_client.latest_response.status_code == HttpStatusCode.OK dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
# set dns server
dns_client.dns_server = server.nics[next(iter(server.nics))].ip_address
# Install Web Server service on server
server.software_manager.install(WebServer)
web_server_service: WebServer = server.software_manager.software.get("WebServer")
web_server_service.start()
# Install DNS Server service on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
# register arcd.com to DNS
dns_server.dns_register(domain_name="arcd.com", domain_ip_address=server.nics[next(iter(server.nics))].ip_address)
return web_browser, computer, web_server_service, server
def test_web_page_get_users_page_request_with_domain_name(uc2_network): def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_server):
"""Test to see if the client can handle requests with domain names""" """Test to see if the client can handle requests with domain names"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
web_client.target_url = "http://arcd.com/users/"
assert web_client.get_webpage() is True web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
# latest response should have status code 200 # latest response should have status code 200
assert web_client.latest_response is not None assert web_browser_app.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
def test_web_page_get_users_page_request_with_ip_address(uc2_network): def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_server):
"""Test to see if the client can handle requests that use ip_address.""" """Test to see if the client can handle requests that use ip_address."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://{web_server_ip}/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address assert web_browser_app.get_webpage() is True
web_client.target_url = f"http://{web_server_ip}/users/"
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage() is True
# latest response should have status code 200 # latest response should have status code 200
assert web_client.latest_response is not None assert web_browser_app.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
def test_web_page_request_from_shut_down_server(uc2_network): def test_web_page_request_from_shut_down_server(web_client_and_web_server):
"""Test to see that the web server does not respond when the server is off.""" """Test to see that the web server does not respond when the server is off."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True
assert web_client.get_webpage("http://arcd.com/users/") is True
# latest response should have status code 200 # latest response should have status code 200
assert web_client.latest_response.status_code == HttpStatusCode.OK assert web_browser_app.latest_response is not None
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
web_server.power_off() server.power_off()
for i in range(web_server.shut_down_duration + 1): server.power_off()
uc2_network.apply_timestep(timestep=i)
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
# node should be off # node should be off
assert web_server.operating_state is NodeOperatingState.OFF assert server.operating_state is NodeOperatingState.OFF
assert web_client.get_webpage("http://arcd.com/users/") is False assert web_browser_app.get_webpage() is False
assert web_client.latest_response.status_code == HttpStatusCode.NOT_FOUND assert web_browser_app.latest_response.status_code == HttpStatusCode.NOT_FOUND
def test_web_page_request_from_closed_web_browser(uc2_network): def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
client_1: Computer = uc2_network.get_node_by_hostname("client_1") web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_server: Server = uc2_network.get_node_by_hostname("web_server") assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_browser_app.target_url = f"http://arcd.com/"
assert web_client.operating_state == ApplicationOperatingState.RUNNING assert web_browser_app.get_webpage() is True
assert web_client.get_webpage("http://arcd.com/users/") is True
# latest response should have status code 200 # latest response should have status code 200
assert web_client.latest_response.status_code == HttpStatusCode.OK assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
web_client.close() web_browser_app.close()
# node should be off # node should be off
assert web_client.operating_state is ApplicationOperatingState.CLOSED assert web_browser_app.operating_state is ApplicationOperatingState.CLOSED
assert web_client.get_webpage("http://arcd.com/users/") is False assert web_browser_app.get_webpage() is False

View File

@@ -0,0 +1,108 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.base import Link
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
@pytest.fixture(scope="function")
def web_client_web_server_database(example_network) -> Tuple[Computer, Server, Server]:
# add rules to network router
router_1: Router = example_network.get_node_by_hostname("router_1")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
# Allow DNS requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
# Allow FTP requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2)
# Open port 80 for web server
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
# Create Computer
computer: Computer = example_network.get_node_by_hostname("client_1")
# Create Web Server
web_server: Server = example_network.get_node_by_hostname("server_1")
# Create Database Server
db_server = example_network.get_node_by_hostname("server_2")
# Get the NICs
computer_nic = computer.nics[next(iter(computer.nics))]
server_nic = web_server.nics[next(iter(web_server.nics))]
db_server_nic = db_server.nics[next(iter(db_server.nics))]
# Connect Computer and Server
link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
# Should be linked
assert link_computer_server.is_up
# Connect database server and web server
link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic)
# Should be linked
assert link_computer_server.is_up
assert link_server_db.is_up
# Install DatabaseService on db server
db_server.software_manager.install(DatabaseService)
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service.start()
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser.run()
# Install DNS Client service on computer
computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
# set dns server
dns_client.dns_server = web_server.nics[next(iter(web_server.nics))].ip_address
# Install Web Server service on web server
web_server.software_manager.install(WebServer)
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
web_server_service.start()
# Install DNS Server service on web server
web_server.software_manager.install(DNSServer)
dns_server: DNSServer = web_server.software_manager.software.get("DNSServer")
# register arcd.com to DNS
dns_server.dns_register(
domain_name="arcd.com", domain_ip_address=web_server.nics[next(iter(web_server.nics))].ip_address
)
# Install DatabaseClient service on web server
web_server.software_manager.install(DatabaseClient)
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_client.server_ip_address = IPv4Address(db_server_nic.ip_address) # set IP address of Database Server
db_client.run()
assert dns_client.check_domain_exists("arcd.com")
assert db_client.connect()
return computer, web_server, db_server
def test_web_client_requests_users(web_client_web_server_database):
computer, web_server, db_server = web_client_web_server_database
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
assert web_browser.get_webpage()

View File

@@ -1,18 +1,140 @@
"""Test the account module of the simulator.""" """Test the account module of the simulator."""
import pytest
from primaite.simulator.domain.account import Account, AccountType from primaite.simulator.domain.account import Account, AccountType
def test_account_serialise(): @pytest.fixture(scope="function")
def account() -> Account:
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
acct.set_original_state()
return acct
def test_original_state(account):
"""Test the original state - see if it resets properly"""
account.log_on()
account.log_off()
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
account.reset_component_for_episode(episode=1)
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
account.log_off()
account.disable()
account.set_original_state()
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 2
account.reset_component_for_episode(episode=2)
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
def test_enable(account):
"""Should enable the account."""
account.enabled = False
account.enable()
assert account.enabled is True
def test_disable(account):
"""Should disable the account."""
account.enabled = True
account.disable()
assert account.enabled is False
def test_log_on_increments(account):
"""Should increase the log on value by 1."""
account.num_logons = 0
account.log_on()
assert account.num_logons is 1
def test_log_off_increments(account):
"""Should increase the log on value by 1."""
account.num_logoffs = 0
account.log_off()
assert account.num_logoffs is 1
def test_account_serialise(account):
"""Test that an account can be serialised. If pydantic throws error then this test fails.""" """Test that an account can be serialised. If pydantic throws error then this test fails."""
acct = Account(username="Jake", password="JakePass1!", account_type=AccountType.USER) serialised = account.model_dump_json()
serialised = acct.model_dump_json()
print(serialised) print(serialised)
def test_account_deserialise(): def test_account_deserialise(account):
"""Test that an account can be deserialised. The test fails if pydantic throws an error.""" """Test that an account can be deserialised. The test fails if pydantic throws an error."""
acct_json = ( acct_json = (
'{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,' '{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,'
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}' '"username":"Jake","password":"totally_hashed_password","account_type":2,"status":2,"request_manager":null}'
) )
acct = Account.model_validate_json(acct_json) assert Account.model_validate_json(acct_json)
def test_describe_state(account):
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_off()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False

View File

@@ -185,6 +185,38 @@ def test_get_file(file_system):
file_system.show(full=True) file_system.show(full=True)
def test_reset_file_system(file_system):
# file and folder that existed originally
file_system.create_file(file_name="test_file.zip")
file_system.create_folder(folder_name="test_folder")
file_system.set_original_state()
# create a new file
file_system.create_file(file_name="new_file.txt")
# create a new folder
file_system.create_folder(folder_name="new_folder")
# delete the file that existed originally
file_system.delete_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
# delete the folder that existed originally
file_system.delete_folder(folder_name="test_folder")
assert file_system.get_folder(folder_name="test_folder") is None
# reset
file_system.reset_component_for_episode(episode=1)
# deleted original file and folder should be back
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_folder(folder_name="test_folder")
# new file and folder should be removed
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
assert file_system.get_folder(folder_name="new_folder") is None
@pytest.mark.skip(reason="Skipping until we tackle serialisation") @pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system): def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly.""" """Test to check that the object serialisation works correctly."""

View File

@@ -0,0 +1,17 @@
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.switch import Switch
@pytest.fixture(scope="function")
def switch() -> Switch:
switch: Switch = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
switch.show()
return switch
def test_describe_state(switch):
state = switch.describe_state()
assert len(state.get("ports")) is 8
assert state.get("num_ports") is 8

View File

@@ -3,6 +3,66 @@ import json
import pytest import pytest
from primaite.simulator.network.container import Network from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
assert len(example_network.switches) is 2
assert len(example_network.computers) is 2
assert len(example_network.servers) is 2
example_network.set_original_state()
example_network.show()
return example_network
def test_describe_state(network):
"""Test that describe state works."""
state = network.describe_state()
assert len(state["nodes"]) is 7
assert len(state["links"]) is 6
def test_reset_network(network):
"""
Test that the network is properly reset.
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
etc are also removed/reinstalled
"""
state_before = network.describe_state()
client_1: Computer = network.get_node_by_hostname("client_1")
server_1: Computer = network.get_node_by_hostname("server_1")
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
server_1.power_off()
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
assert network.describe_state() is not state_before
network.reset_component_for_episode(episode=1)
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
def test_creating_container(): def test_creating_container():
@@ -10,11 +70,46 @@ def test_creating_container():
net = Network() net = Network()
assert net.nodes == {} assert net.nodes == {}
assert net.links == {} assert net.links == {}
net.show()
@pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_apply_timestep_to_nodes(network):
def test_describe_state(): """Calling apply_timestep on the network should apply to the nodes within it."""
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable.""" client_1: Computer = network.get_node_by_hostname("client_1")
net = Network() assert client_1.operating_state is NodeOperatingState.ON
state = net.describe_state()
json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
network.apply_timestep(timestep=i)
assert client_1.operating_state is NodeOperatingState.OFF
def test_removing_node_that_does_not_exist(network):
"""Node that does not exist on network should not affect existing nodes."""
assert len(network.nodes) is 7
network.remove_node(Node(hostname="new_node"))
assert len(network.nodes) is 7
def test_remove_node(network):
"""Remove node should remove the correct node."""
assert len(network.nodes) is 7
client_1: Computer = network.get_node_by_hostname("client_1")
network.remove_node(client_1)
assert network.get_node_by_hostname("client_1") is None
assert len(network.nodes) is 6
def test_remove_link(network):
"""Remove link should remove the correct link."""
assert len(network.links) is 6
link: Link = network.links.get(next(iter(network.links)))
network.remove_link(link)
assert len(network.links) is 5
assert network.links.get(link.uuid) is None

View File

@@ -0,0 +1,11 @@
from primaite.simulator.network.utils import convert_bytes_to_megabits, convert_megabits_to_bytes
def test_convert_bytes_to_megabits():
assert round(convert_bytes_to_megabits(B=131072), 5) == float(1)
assert round(convert_bytes_to_megabits(B=69420), 5) == float(0.52963)
def test_convert_megabits_to_bytes():
assert round(convert_megabits_to_bytes(Mbits=1), 5) == float(131072)
assert round(convert_megabits_to_bytes(Mbits=float(0.52963)), 5) == float(69419.66336)

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
from typing import Tuple, Union
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
@pytest.fixture(scope="function")
def database_client_on_computer() -> Tuple[DatabaseClient, Computer]:
computer = Computer(
hostname="db_node", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
computer.software_manager.install(DatabaseClient)
database_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
database_client.run()
return database_client, computer
def test_creation(database_client_on_computer):
database_client, computer = database_client_on_computer
database_client.describe_state()
def test_connect_when_client_is_closed(database_client_on_computer):
"""Database client should not connect when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.connect() is False
def test_connect_to_database_fails_on_reattempt(database_client_on_computer):
"""Database client should return False when the attempt to connect fails."""
database_client, computer = database_client_on_computer
database_client.connected = False
assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False
def test_disconnect_when_client_is_closed(database_client_on_computer):
"""Database client disconnect should not do anything when it is not running."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.server_ip_address is not None
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.disconnect()
assert database_client.connected is True
assert database_client.server_ip_address is not None
def test_disconnect(database_client_on_computer):
"""Database client should set connected to False and remove the database server ip address."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.operating_state is ApplicationOperatingState.RUNNING
assert database_client.server_ip_address is not None
database_client.disconnect()
assert database_client.connected is False
assert database_client.server_ip_address is None
def test_query_when_client_is_closed(database_client_on_computer):
"""Database client should return False when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.query(sql="test") is False
def test_query_failed_reattempt(database_client_on_computer):
"""Database client query should return False if the reattempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test", is_reattempt=True) is False
def test_query_fail_to_connect(database_client_on_computer):
"""Database client query should return False if the connect attempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test") is False
def test_client_receives_response_when_closed(database_client_on_computer):
"""Database client receive should return False when it is closed."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.receive(payload={}, session_id="")

View File

@@ -1,39 +1,66 @@
from typing import Tuple
import pytest import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.applications.web_browser import WebBrowser
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def web_client() -> Computer: def web_browser() -> WebBrowser:
node = Computer( computer = Computer(
hostname="web_client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
) )
return node # Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.run()
assert web_browser.operating_state is ApplicationOperatingState.RUNNING
return web_browser
def test_create_web_client(web_client): def test_create_web_client():
assert web_client is not None computer = Computer(
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"] hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
assert web_browser.name is "WebBrowser" assert web_browser.name is "WebBrowser"
assert web_browser.port is Port.HTTP assert web_browser.port is Port.HTTP
assert web_browser.protocol is IPProtocol.TCP assert web_browser.protocol is IPProtocol.TCP
def test_receive_invalid_payload(web_client): def test_receive_invalid_payload(web_browser):
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
assert web_browser.receive(payload={}) is False assert web_browser.receive(payload={}) is False
def test_receive_payload(web_client): def test_receive_payload(web_browser):
payload = HttpResponsePacket(status_code=HttpStatusCode.OK) payload = HttpResponsePacket(status_code=HttpStatusCode.OK)
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
assert web_browser.latest_response is None assert web_browser.latest_response is None
web_browser.receive(payload=payload) web_browser.receive(payload=payload)
assert web_browser.latest_response is not None assert web_browser.latest_response is not None
def test_invalid_target_url(web_browser):
# none value target url
web_browser.target_url = None
assert web_browser.get_webpage() is False
def test_non_existent_target_url(web_browser):
web_browser.target_url = "http://192.168.255.255"
assert web_browser.get_webpage() is False

View File

@@ -19,11 +19,11 @@ def dm_client() -> Node:
@pytest.fixture @pytest.fixture
def dm_bot(dm_client) -> DataManipulationBot: def dm_bot(dm_client) -> DataManipulationBot:
return dm_client.software_manager.software["DataManipulationBot"] return dm_client.software_manager.software.get("DataManipulationBot")
def test_create_dm_bot(dm_client): def test_create_dm_bot(dm_client):
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"] data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
assert data_manipulation_bot.name == "DataManipulationBot" assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER assert data_manipulation_bot.port == Port.POSTGRES_SERVER

View File

@@ -8,7 +8,7 @@ from primaite.simulator.system.services.database.database_service import Databas
def database_server() -> Node: def database_server() -> Node:
node = Node(hostname="db_node") node = Node(hostname="db_node")
node.software_manager.install(DatabaseService) node.software_manager.install(DatabaseService)
node.software_manager.software["DatabaseService"].start() node.software_manager.software.get("DatabaseService").start()
return node return node

View File

@@ -5,28 +5,13 @@ import pytest
from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Server(
hostname="dns_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
node.software_manager.install(software_class=DNSServer)
return node
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def dns_client() -> Node: def dns_client() -> Node:
node = Computer( node = Computer(
@@ -39,24 +24,16 @@ def dns_client() -> Node:
return node return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
def test_create_dns_client(dns_client): def test_create_dns_client(dns_client):
assert dns_client is not None assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.name is "DNSClient" assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_client_add_domain_to_cache_when_not_running(dns_client): def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client.operating_state is NodeOperatingState.OFF assert dns_client.operating_state is NodeOperatingState.OFF
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
@@ -69,7 +46,7 @@ def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
def test_dns_client_check_domain_exists_when_not_running(dns_client): def test_dns_client_check_domain_exists_when_not_running(dns_client):
dns_client.operating_state = NodeOperatingState.ON dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start() dns_client_service.start()
assert dns_client.operating_state is NodeOperatingState.ON assert dns_client.operating_state is NodeOperatingState.ON
@@ -93,22 +70,10 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client):
assert dns_client_service.check_domain_exists("test.com") is False assert dns_client_service.check_domain_exists("test.com") is False
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
# return none for an unknown domain
assert dns_server_service.dns_lookup("fake-domain.com") is None
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_client_check_domain_in_cache(dns_client): def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values.""" """Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client.operating_state = NodeOperatingState.ON dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start() dns_client_service.start()
# add a domain to the dns client cache # add a domain to the dns client cache
@@ -118,29 +83,9 @@ def test_dns_client_check_domain_in_cache(dns_client):
assert dns_client_service.check_domain_exists("real-domain.com") is True assert dns_client_service.check_domain_exists("real-domain.com") is True
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
is False
)
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
is True
)
dns_server_service.show()
def test_dns_client_receive(dns_client): def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses.""" """Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"] dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.receive( dns_client_service.receive(
payload=DNSPacket( payload=DNSPacket(
@@ -151,3 +96,9 @@ def test_dns_client_receive(dns_client):
# domain name should be saved to cache # domain name should be saved to cache
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12") assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")
def test_dns_client_receive_non_dns_payload(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.receive(payload=None) is False

View File

@@ -0,0 +1,64 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns.dns_server import DNSServer
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Server(
hostname="dns_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
node.software_manager.install(software_class=DNSServer)
return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
# return none for an unknown domain
assert dns_server_service.dns_lookup("fake-domain.com") is None
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
is False
)
assert (
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
is True
)
dns_server_service.show()

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function")
def ftp_client() -> Node:
node = Computer(
hostname="ftp_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
return node
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.name is "FTPClient"
assert ftp_client_service.port is Port.FTP
assert ftp_client_service.protocol is IPProtocol.TCP
def test_ftp_client_store_file(ftp_client):
"""Test to make sure the FTP Client knows how to deal with request responses."""
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
response: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
ftp_command_args={
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
},
packet_payload_size=24,
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.receive(response)
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
"""Method _process_ftp_command should return false if service is not running."""
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.stop()
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
"""Method send_file should return false if no file to send."""
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING
assert (
ftp_client_service.send_file(
dest_ip_address=IPv4Address("192.168.1.1"),
src_folder_name="root",
src_file_name="test.txt",
dest_folder_name="root",
dest_file_name="text.txt",
)
is False
)
def test_offline_ftp_client_receives_request(ftp_client):
"""Receive should return false if the node the ftp client is installed on is offline."""
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client.power_off()
for i in range(ftp_client.shut_down_duration + 1):
ftp_client.apply_timestep(timestep=i)
assert ftp_client.operating_state is NodeOperatingState.OFF
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=FTPStatusCode.OK,
)
assert ftp_client_service.receive(payload=payload) is False
def test_receive_should_fail_if_payload_is_not_ftp(ftp_client):
"""Receive should return false if the node the ftp client is installed on is not an FTPPacket."""
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=None) is False
def test_receive_should_ignore_payload_with_none_status_code(ftp_client):
"""Receive should ignore payload with no set status code to prevent infinite send/receive loops."""
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=None,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=payload) is False

View File

@@ -1,16 +1,13 @@
from ipaddress import IPv4Address
import pytest import pytest
from primaite.simulator.network.hardware.base import Node from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@@ -26,34 +23,14 @@ def ftp_server() -> Node:
return node return node
@pytest.fixture(scope="function")
def ftp_client() -> Node:
node = Computer(
hostname="ftp_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
return node
def test_create_ftp_server(ftp_server): def test_create_ftp_server(ftp_server):
assert ftp_server is not None assert ftp_server is not None
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"] ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.name is "FTPServer" assert ftp_server_service.name is "FTPServer"
assert ftp_server_service.port is Port.FTP assert ftp_server_service.port is Port.FTP
assert ftp_server_service.protocol is IPProtocol.TCP assert ftp_server_service.protocol is IPProtocol.TCP
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
assert ftp_client_service.name is "FTPClient"
assert ftp_client_service.port is Port.FTP
assert ftp_client_service.protocol is IPProtocol.TCP
def test_ftp_server_store_file(ftp_server): def test_ftp_server_store_file(ftp_server):
"""Test to make sure the FTP Server knows how to deal with request responses.""" """Test to make sure the FTP Server knows how to deal with request responses."""
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") is None assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
@@ -68,16 +45,34 @@ def test_ftp_server_store_file(ftp_server):
packet_payload_size=24, packet_payload_size=24,
) )
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"] ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_server_service.receive(response) ftp_server_service.receive(response)
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt")
def test_ftp_client_store_file(ftp_client): def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server):
"""Test to make sure the FTP Client knows how to deal with request responses.""" """Should fail if the port command receives an invalid port."""
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=None,
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
def test_ftp_server_receives_non_ftp_packet(ftp_server):
"""Receive should return false if the service receives a non ftp packet."""
response: FTPPacket = None
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.receive(response) is False
def test_offline_ftp_server_receives_request(ftp_server):
"""Receive should return false if the service is stopped."""
response: FTPPacket = FTPPacket( response: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR, ftp_command=FTPCommand.STOR,
ftp_command_args={ ftp_command_args={
@@ -86,10 +81,9 @@ def test_ftp_client_store_file(ftp_client):
"file_size": 24, "file_size": 24,
}, },
packet_payload_size=24, packet_payload_size=24,
status_code=FTPStatusCode.OK,
) )
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"] ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_client_service.receive(response) ftp_server_service.stop()
assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") assert ftp_server_service.receive(response) is False

View File

@@ -18,13 +18,13 @@ def web_server() -> Server:
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1" hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
) )
node.software_manager.install(software_class=WebServer) node.software_manager.install(software_class=WebServer)
node.software_manager.software["WebServer"].start() node.software_manager.software.get("WebServer").start()
return node return node
def test_create_web_server(web_server): def test_create_web_server(web_server):
assert web_server is not None assert web_server is not None
web_server_service: WebServer = web_server.software_manager.software["WebServer"] web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service.name is "WebServer" assert web_server_service.name is "WebServer"
assert web_server_service.port is Port.HTTP assert web_server_service.port is Port.HTTP
assert web_server_service.protocol is IPProtocol.TCP assert web_server_service.protocol is IPProtocol.TCP
@@ -33,7 +33,7 @@ def test_create_web_server(web_server):
def test_handling_get_request_not_found_path(web_server): def test_handling_get_request_not_found_path(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path") payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path")
web_server_service: WebServer = web_server.software_manager.software["WebServer"] web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.NOT_FOUND assert response.status_code == HttpStatusCode.NOT_FOUND
@@ -42,7 +42,7 @@ def test_handling_get_request_not_found_path(web_server):
def test_handling_get_request_home_page(web_server): def test_handling_get_request_home_page(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/") payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"] web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload) response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.OK assert response.status_code == HttpStatusCode.OK
@@ -51,7 +51,7 @@ def test_handling_get_request_home_page(web_server):
def test_process_http_request_get(web_server): def test_process_http_request_get(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/") payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"] web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is True assert web_server_service._process_http_request(payload=payload) is True
@@ -59,6 +59,6 @@ def test_process_http_request_get(web_server):
def test_process_http_request_method_not_allowed(web_server): def test_process_http_request_method_not_allowed(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/") payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"] web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is False assert web_server_service._process_http_request(payload=payload) is False