Merge 'origin/dev' into bugfix/episode-length-and-rewards
This commit is contained in:
@@ -6,6 +6,9 @@ trigger:
|
||||
- bugfix/*
|
||||
- release/*
|
||||
|
||||
pr:
|
||||
autoCancel: true
|
||||
drafts: false
|
||||
parameters:
|
||||
# https://stackoverflow.com/a/70046417
|
||||
- name: matrix
|
||||
@@ -85,6 +88,33 @@ stages:
|
||||
primaite setup
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
- task: UseDotNet@2
|
||||
displayName: 'Install dotnet dependencies'
|
||||
inputs:
|
||||
packageType: 'sdk'
|
||||
version: '2.1.x'
|
||||
|
||||
- script: |
|
||||
pytest -n auto
|
||||
displayName: 'Run tests'
|
||||
coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml
|
||||
coverage xml -o coverage.xml -i
|
||||
coverage html -d htmlcov -i
|
||||
displayName: 'Run tests and code coverage'
|
||||
|
||||
- task: PublishTestResults@2
|
||||
condition: succeededOrFailed()
|
||||
inputs:
|
||||
testRunner: JUnit
|
||||
testResultsFiles: 'junit/**.xml'
|
||||
testRunTitle: 'Publish test results'
|
||||
|
||||
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
|
||||
# publish the html report - so we can debug the coverage if needed
|
||||
condition: ${{ item.every_time }} # should only be run once
|
||||
artifact: coverage_report
|
||||
|
||||
- task: PublishCodeCoverageResults@2
|
||||
# publish the code coverage so it can be viewed in the run coverage page
|
||||
condition: ${{ item.every_time }} # should only be run once
|
||||
inputs:
|
||||
codeCoverageTool: Cobertura
|
||||
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -37,6 +37,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
junit/
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -54,7 +54,7 @@ Example
|
||||
)
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
data_manipulation_bot.run()
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ See :ref:`Node Start up and Shut down`
|
||||
|
||||
node.software_manager.install(WebServer)
|
||||
|
||||
web_server: WebServer = node.software_manager.software["WebServer"]
|
||||
web_server: WebServer = node.software_manager.software.get("WebServer")
|
||||
assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install
|
||||
|
||||
node.power_off()
|
||||
|
||||
@@ -424,7 +424,7 @@ class NetworkACLAddRuleAction(AbstractAction):
|
||||
elif permission == 2:
|
||||
permission_str = "DENY"
|
||||
else:
|
||||
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
|
||||
|
||||
if protocol_id == 0:
|
||||
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
|
||||
|
||||
@@ -264,7 +264,7 @@ class FolderObservation(AbstractObservation):
|
||||
while len(self.files) > num_files_per_folder:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
@@ -438,7 +438,7 @@ class NodeObservation(AbstractObservation):
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
# truncate service list
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
@@ -448,7 +448,7 @@ class NodeObservation(AbstractObservation):
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.nics: List[NicObservation] = nics
|
||||
while len(self.nics) < num_nics_per_node:
|
||||
@@ -456,7 +456,7 @@ class NodeObservation(AbstractObservation):
|
||||
while len(self.nics) > num_nics_per_node:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
|
||||
"found in reward config."
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
|
||||
node_uuid = game.ref_map_nodes[node_ref]
|
||||
service_uuid = game.ref_map_services[service_ref]
|
||||
@@ -219,7 +219,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
|
||||
" found in the simulator."
|
||||
)
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
return DummyReward() # TODO: consider erroring here as well
|
||||
|
||||
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
|
||||
|
||||
@@ -252,6 +252,6 @@ class SimComponent(BaseModel):
|
||||
def parent(self, new_parent: Union["SimComponent", None]) -> None:
|
||||
if self._parent and new_parent:
|
||||
msg = f"Overwriting parent of {self.uuid}. Old parent: {self._parent.uuid}, New parent: {new_parent.uuid}"
|
||||
_LOGGER.warn(msg)
|
||||
_LOGGER.warning(msg)
|
||||
raise RuntimeWarning(msg)
|
||||
self._parent = new_parent
|
||||
|
||||
@@ -72,7 +72,7 @@ class Account(SimComponent):
|
||||
"num_group_changes": self.num_group_changes,
|
||||
"username": self.username,
|
||||
"password": self.password,
|
||||
"account_type": self.account_type.name,
|
||||
"account_type": self.account_type.value,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -53,7 +53,10 @@ class FileSystem(SimComponent):
|
||||
original_folder_uuids = self._original_state["original_folder_uuids"]
|
||||
for uuid in original_folder_uuids:
|
||||
if uuid in self.deleted_folders:
|
||||
self.folders[uuid] = self.deleted_folders.pop(uuid)
|
||||
folder = self.deleted_folders[uuid]
|
||||
self.deleted_folders.pop(uuid)
|
||||
self.folders[uuid] = folder
|
||||
self._folders_by_name[folder.name] = folder
|
||||
|
||||
# Clear any other deleted folders that aren't original (have been created by agent)
|
||||
self.deleted_folders.clear()
|
||||
@@ -62,7 +65,9 @@ class FileSystem(SimComponent):
|
||||
current_folder_uuids = list(self.folders.keys())
|
||||
for uuid in current_folder_uuids:
|
||||
if uuid not in original_folder_uuids:
|
||||
folder = self.folders[uuid]
|
||||
self.folders.pop(uuid)
|
||||
self._folders_by_name.pop(folder.name)
|
||||
|
||||
# Now reset all remaining folders
|
||||
for folder in self.folders.values():
|
||||
|
||||
@@ -75,7 +75,10 @@ class Folder(FileSystemItemABC):
|
||||
original_file_uuids = self._original_state["original_file_uuids"]
|
||||
for uuid in original_file_uuids:
|
||||
if uuid in self.deleted_files:
|
||||
self.files[uuid] = self.deleted_files.pop(uuid)
|
||||
file = self.deleted_files[uuid]
|
||||
self.deleted_files.pop(uuid)
|
||||
self.files[uuid] = file
|
||||
self._files_by_name[file.name] = file
|
||||
|
||||
# Clear any other deleted files that aren't original (have been created by agent)
|
||||
self.deleted_files.clear()
|
||||
@@ -84,7 +87,9 @@ class Folder(FileSystemItemABC):
|
||||
current_file_uuids = list(self.files.keys())
|
||||
for uuid in current_file_uuids:
|
||||
if uuid not in original_file_uuids:
|
||||
file = self.files[uuid]
|
||||
self.files.pop(uuid)
|
||||
self._files_by_name.pop(file.name)
|
||||
|
||||
# Now reset all remaining files
|
||||
for file in self.files.values():
|
||||
|
||||
@@ -51,14 +51,22 @@ def client_server_routed() -> Network:
|
||||
|
||||
# Client 1
|
||||
client_1 = Computer(
|
||||
hostname="client_1", ip_address="192.168.2.2", subnet_mask="255.255.255.0", default_gateway="192.168.2.1"
|
||||
hostname="client_1",
|
||||
ip_address="192.168.2.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.2.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
|
||||
# Server 1
|
||||
server_1 = Server(
|
||||
hostname="server_1", ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
hostname="server_1",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
server_1.power_on()
|
||||
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
|
||||
@@ -139,7 +147,7 @@ def arcd_uc2_network() -> Network:
|
||||
client_1.power_on()
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
db_manipulation_bot.configure(
|
||||
server_ip_address=IPv4Address("192.168.1.14"),
|
||||
payload="DELETE",
|
||||
@@ -157,7 +165,7 @@ def arcd_uc2_network() -> Network:
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
client_2.power_on()
|
||||
web_browser = client_2.software_manager.software["WebBrowser"]
|
||||
web_browser = client_2.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
|
||||
|
||||
@@ -241,7 +249,7 @@ def arcd_uc2_network() -> Network:
|
||||
# noqa
|
||||
]
|
||||
database_server.software_manager.install(DatabaseService)
|
||||
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
|
||||
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
|
||||
database_service.start()
|
||||
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
|
||||
database_service._process_sql(ddl, None) # noqa
|
||||
@@ -260,7 +268,7 @@ def arcd_uc2_network() -> Network:
|
||||
web_server.power_on()
|
||||
web_server.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
|
||||
database_client.run()
|
||||
@@ -269,7 +277,7 @@ def arcd_uc2_network() -> Network:
|
||||
web_server.software_manager.install(WebServer)
|
||||
|
||||
# register the web_server to a domain
|
||||
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
|
||||
dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa
|
||||
dns_server_service.dns_register("arcd.com", web_server.ip_address)
|
||||
|
||||
# Backup Server
|
||||
|
||||
@@ -5,6 +5,8 @@ def convert_bytes_to_megabits(B: Union[int, float]) -> float: # noqa - Keep it
|
||||
"""
|
||||
Convert Bytes (file size) to Megabits (data transfer).
|
||||
|
||||
Technically Mebibits - but for simplicity sake, we'll call it megabit
|
||||
|
||||
:param B: The file size in Bytes.
|
||||
:return: File bits to transfer in Megabits.
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,8 @@ class DatabaseClient(Application):
|
||||
|
||||
if not self.connected:
|
||||
return self._connect(self.server_ip_address, self.server_password)
|
||||
return False
|
||||
# already connected
|
||||
return True
|
||||
|
||||
def _connect(
|
||||
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
|
||||
@@ -107,7 +108,7 @@ class DatabaseClient(Application):
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from the Database Service."""
|
||||
if self.connected and self.operating_state.RUNNING:
|
||||
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
|
||||
@@ -186,6 +187,9 @@ class DatabaseClient(Application):
|
||||
:param session_id: The session id the payload relates to.
|
||||
:return: True.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
self.connected = payload["response"] == True
|
||||
|
||||
@@ -99,7 +99,7 @@ class WebBrowser(Application):
|
||||
return False
|
||||
|
||||
# get the IP address of the domain name via DNS
|
||||
dns_client: DNSClient = self.software_manager.software["DNSClient"]
|
||||
dns_client: DNSClient = self.software_manager.software.get("DNSClient")
|
||||
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
|
||||
|
||||
# if domain does not exist, the request fails
|
||||
|
||||
@@ -80,7 +80,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# send backup copy of database file to FTP server
|
||||
response = ftp_client_service.send_file(
|
||||
@@ -104,7 +104,7 @@ class DatabaseService(Service):
|
||||
return False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
|
||||
|
||||
# retrieve backup file from backup server
|
||||
response = ftp_client_service.request_file(
|
||||
|
||||
@@ -8,7 +8,6 @@ from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -53,8 +52,7 @@ class FTPClient(FTPServiceABC):
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
# if client service is down, return error
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.error("FTP Client is not running")
|
||||
if not self._can_perform_action():
|
||||
payload.status_code = FTPStatusCode.ERROR
|
||||
return payload
|
||||
|
||||
@@ -81,8 +79,7 @@ class FTPClient(FTPServiceABC):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
# make sure the service is running before attempting
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.error(f"FTPClient not running for {self.sys_log.hostname}")
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
# normally FTP will choose a random port for the transfer, but using the FTP command port will do for now
|
||||
@@ -282,8 +279,11 @@ class FTPClient(FTPServiceABC):
|
||||
This helps prevent an FTP request loop - FTP client and servers can exist on
|
||||
the same node.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if payload.status_code is None:
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {payload.status_code.value}")
|
||||
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
|
||||
return False
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
@@ -6,7 +6,6 @@ from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPS
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -58,8 +57,7 @@ class FTPServer(FTPServiceABC):
|
||||
payload.status_code = FTPStatusCode.ERROR
|
||||
|
||||
# if server service is down, return error
|
||||
if self.operating_state != ServiceOperatingState.RUNNING:
|
||||
self.sys_log.error("FTP Server not running")
|
||||
if not self._can_perform_action():
|
||||
return payload
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
|
||||
@@ -95,6 +93,9 @@ class FTPServer(FTPServiceABC):
|
||||
self.sys_log.error(f"{payload} is not an FTP packet")
|
||||
return False
|
||||
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
return False
|
||||
|
||||
"""
|
||||
Ignore ftp payload if status code is defined.
|
||||
|
||||
@@ -102,9 +103,6 @@ class FTPServer(FTPServiceABC):
|
||||
prevents an FTP request loop - FTP client and servers can exist on
|
||||
the same node.
|
||||
"""
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
return False
|
||||
|
||||
if payload.status_code is not None:
|
||||
return False
|
||||
|
||||
|
||||
@@ -109,8 +109,8 @@ class Service(IOSoftware):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["operating_state"] = self.operating_state.value
|
||||
state["health_state_actual"] = self.health_state_actual
|
||||
state["health_state_visible"] = self.health_state_visible
|
||||
state["health_state_actual"] = self.health_state_actual.value
|
||||
state["health_state_visible"] = self.health_state_visible.value
|
||||
return state
|
||||
|
||||
def stop(self) -> None:
|
||||
|
||||
@@ -119,7 +119,7 @@ class WebServer(Service):
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
# query succeeded
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -12,6 +12,11 @@ from primaite.session.session import PrimaiteSession
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.networks import arcd_uc2_network
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -29,7 +34,7 @@ from primaite import PRIMAITE_PATHS
|
||||
|
||||
# PrimAITE v3 stuff
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
|
||||
|
||||
class TestService(Service):
|
||||
@@ -122,3 +127,110 @@ def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
|
||||
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
|
||||
config_path = request.param[0]
|
||||
return TempPrimaiteSession.from_config(config_path=config_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_server() -> Tuple[Computer, Server]:
|
||||
# Create Computer
|
||||
computer: Computer = Computer(
|
||||
hostname="test_computer",
|
||||
ip_address="192.168.0.1",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
# Create Server
|
||||
server = Server(
|
||||
hostname="server", ip_address="192.168.0.2", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
|
||||
)
|
||||
|
||||
# Connect Computer and Server
|
||||
computer_nic = computer.nics[next(iter(computer.nics))]
|
||||
server_nic = server.nics[next(iter(server.nics))]
|
||||
link = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
|
||||
|
||||
# Should be linked
|
||||
assert link.is_up
|
||||
|
||||
return computer, server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def example_network() -> Network:
|
||||
"""
|
||||
Create the network used for testing.
|
||||
|
||||
Should only contain the nodes and links.
|
||||
This would act as the base network and services and applications are installed in the relevant test file,
|
||||
|
||||
-------------- --------------
|
||||
| client_1 |----- ----| server_1 |
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
------| switch_1 |------| router_1 |------| switch_2 |------
|
||||
-------------- | -------------- -------------- -------------- | --------------
|
||||
| client_2 |---- ----| server_2 |
|
||||
-------------- --------------
|
||||
"""
|
||||
network = Network()
|
||||
|
||||
# Router 1
|
||||
router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
|
||||
|
||||
# Switch 1
|
||||
switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
|
||||
network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])
|
||||
router_1.enable_port(1)
|
||||
|
||||
# Switch 2
|
||||
switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON)
|
||||
network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])
|
||||
router_1.enable_port(2)
|
||||
|
||||
# Client 1
|
||||
client_1 = Computer(
|
||||
hostname="client_1",
|
||||
ip_address="192.168.10.21",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
|
||||
|
||||
# Client 2
|
||||
client_2 = Computer(
|
||||
hostname="client_2",
|
||||
ip_address="192.168.10.22",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.10.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
|
||||
|
||||
# Domain Controller
|
||||
server_1 = Server(
|
||||
hostname="server_1",
|
||||
ip_address="192.168.1.10",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
|
||||
|
||||
# Database Server
|
||||
server_2 = Server(
|
||||
hostname="server_2",
|
||||
ip_address="192.168.1.14",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
network.connect(endpoint_b=server_2.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
|
||||
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
|
||||
|
||||
return network
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
@@ -10,6 +11,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="no way of currently testing this")
|
||||
def test_sb3_compatibility():
|
||||
"""Test that the Gymnasium environment can be used with an SB3 agent."""
|
||||
with open(example_config_path(), "r") as f:
|
||||
|
||||
@@ -11,6 +11,7 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="no way of currently testing this")
|
||||
class TestPrimaiteSession:
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_creating_session(self, temp_primaite_session):
|
||||
|
||||
@@ -8,13 +8,13 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
|
||||
def test_data_manipulation(uc2_network):
|
||||
"""Tests the UC2 data manipulation scenario end-to-end. Is a work in progress."""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
|
||||
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
|
||||
database_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = database_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = database_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
db_service.backup_database()
|
||||
|
||||
|
||||
@@ -16,3 +16,9 @@ def test_link_up():
|
||||
assert nic_a.enabled
|
||||
assert nic_b.enabled
|
||||
assert link.is_up
|
||||
|
||||
|
||||
def test_ping_between_computer_and_server(client_server):
|
||||
computer, server = client_server
|
||||
|
||||
assert computer.ping(target_ip_address=server.nics[next(iter(server.nics))].ip_address)
|
||||
|
||||
@@ -2,6 +2,28 @@ import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.networks import client_server_routed
|
||||
|
||||
|
||||
def test_network(example_network):
|
||||
network: Network = example_network
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
server_1: Server = network.get_node_by_hostname("server_1")
|
||||
server_2: Server = network.get_node_by_hostname("server_2")
|
||||
|
||||
assert client_1.ping(client_2.ethernet_port[1].ip_address)
|
||||
assert client_2.ping(client_1.ethernet_port[1].ip_address)
|
||||
|
||||
assert server_1.ping(server_2.ethernet_port[1].ip_address)
|
||||
assert server_2.ping(server_1.ethernet_port[1].ip_address)
|
||||
|
||||
assert client_1.ping(server_1.ethernet_port[1].ip_address)
|
||||
assert client_2.ping(server_1.ethernet_port[1].ip_address)
|
||||
assert client_1.ping(server_2.ethernet_port[1].ip_address)
|
||||
assert client_2.ping(server_2.ethernet_port[1].ip_address)
|
||||
|
||||
|
||||
def test_adding_removing_nodes():
|
||||
|
||||
@@ -18,7 +18,7 @@ def populated_node(application_class) -> Tuple[Application, Computer]:
|
||||
)
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app = computer.software_manager.software["TestApplication"]
|
||||
app = computer.software_manager.software.get("TestApplication")
|
||||
app.run()
|
||||
|
||||
return app, computer
|
||||
@@ -35,7 +35,7 @@ def test_service_on_offline_node(application_class):
|
||||
)
|
||||
computer.software_manager.install(application_class)
|
||||
|
||||
app: Application = computer.software_manager.software["TestApplication"]
|
||||
app: Application = computer.software_manager.software.get("TestApplication")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
def test_database_client_server_connection(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
@@ -23,10 +23,10 @@ def test_database_client_server_connection(uc2_network):
|
||||
|
||||
def test_database_client_server_correct_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
db_client.disconnect()
|
||||
|
||||
@@ -40,10 +40,10 @@ def test_database_client_server_correct_password(uc2_network):
|
||||
|
||||
def test_database_client_server_incorrect_password(uc2_network):
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
db_client.disconnect()
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
|
||||
@@ -56,7 +56,7 @@ def test_database_client_server_incorrect_password(uc2_network):
|
||||
def test_database_client_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
assert db_client.connected
|
||||
|
||||
@@ -66,13 +66,13 @@ def test_database_client_query(uc2_network):
|
||||
def test_create_database_backup(uc2_network):
|
||||
"""Run the backup_database method and check if the FTP server has the relevant file."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
# back up should be created
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
|
||||
|
||||
# backup file should exist in the backup server
|
||||
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
|
||||
@@ -81,7 +81,7 @@ def test_create_database_backup(uc2_network):
|
||||
def test_restore_backup(uc2_network):
|
||||
"""Run the restore_backup method and check if the backup is properly restored."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
# create a back up
|
||||
assert db_service.backup_database() is True
|
||||
@@ -100,13 +100,13 @@ def test_restore_backup(uc2_network):
|
||||
def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 404 when db server is offline."""
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server.operating_state is NodeOperatingState.ON
|
||||
assert db_service.operating_state is ServiceOperatingState.RUNNING
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
assert db_client.connected
|
||||
|
||||
assert db_client.query("SELECT") is True
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
@@ -6,12 +11,31 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
def test_dns_client_server(uc2_network):
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_client_and_dns_server(client_server) -> Tuple[DNSClient, Computer, DNSServer, Server]:
|
||||
computer, server = client_server
|
||||
|
||||
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
|
||||
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
|
||||
# Install DNS Client on computer
|
||||
computer.software_manager.install(DNSClient)
|
||||
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
|
||||
dns_client.start()
|
||||
# set server as DNS Server
|
||||
dns_client.dns_server = IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
|
||||
|
||||
# Install DNS Server on server
|
||||
server.software_manager.install(DNSServer)
|
||||
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
|
||||
dns_server.start()
|
||||
# register arcd.com as a domain
|
||||
dns_server.dns_register(
|
||||
domain_name="arcd.com", domain_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
|
||||
)
|
||||
|
||||
return dns_client, computer, dns_server, server
|
||||
|
||||
|
||||
def test_dns_client_server(dns_client_and_dns_server):
|
||||
dns_client, computer, dns_server, server = dns_client_and_dns_server
|
||||
|
||||
assert dns_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert dns_server.operating_state == ServiceOperatingState.RUNNING
|
||||
@@ -29,12 +53,8 @@ def test_dns_client_server(uc2_network):
|
||||
assert len(dns_client.dns_cache) == 1
|
||||
|
||||
|
||||
def test_dns_client_requests_offline_dns_server(uc2_network):
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
|
||||
|
||||
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
|
||||
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
|
||||
def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server):
|
||||
dns_client, computer, dns_server, server = dns_client_and_dns_server
|
||||
|
||||
assert dns_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert dns_server.operating_state == ServiceOperatingState.RUNNING
|
||||
@@ -48,12 +68,12 @@ def test_dns_client_requests_offline_dns_server(uc2_network):
|
||||
assert len(dns_client.dns_cache) == 1
|
||||
dns_client.dns_cache = {}
|
||||
|
||||
domain_controller.power_off()
|
||||
server.power_off()
|
||||
|
||||
for i in range(domain_controller.shut_down_duration + 1):
|
||||
uc2_network.apply_timestep(timestep=i)
|
||||
for i in range(server.shut_down_duration + 1):
|
||||
server.apply_timestep(timestep=i)
|
||||
|
||||
assert domain_controller.operating_state == NodeOperatingState.OFF
|
||||
assert server.operating_state == NodeOperatingState.OFF
|
||||
assert dns_server.operating_state == ServiceOperatingState.STOPPED
|
||||
|
||||
# this time it should not cache because dns server is not online
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
@@ -7,18 +10,31 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
def test_ftp_client_store_file_in_server(uc2_network):
|
||||
@pytest.fixture(scope="function")
|
||||
def ftp_client_and_ftp_server(client_server) -> Tuple[FTPClient, Computer, FTPServer, Server]:
|
||||
computer, server = client_server
|
||||
|
||||
# Install FTP Client service on computer
|
||||
computer.software_manager.install(FTPClient)
|
||||
ftp_client: FTPClient = computer.software_manager.software.get("FTPClient")
|
||||
ftp_client.start()
|
||||
|
||||
# Install FTP Server service on server
|
||||
server.software_manager.install(FTPServer)
|
||||
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
|
||||
ftp_server.start()
|
||||
|
||||
return ftp_client, computer, ftp_server, server
|
||||
|
||||
|
||||
def test_ftp_client_store_file_in_server(ftp_client_and_ftp_server):
|
||||
"""
|
||||
Test checks to see if the client is able to store files in the backup server.
|
||||
"""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
|
||||
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
|
||||
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
|
||||
|
||||
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
|
||||
|
||||
# create file on ftp client
|
||||
ftp_client.file_system.create_file(file_name="test_file.txt")
|
||||
@@ -28,61 +44,53 @@ def test_ftp_client_store_file_in_server(uc2_network):
|
||||
src_file_name="test_file.txt",
|
||||
dest_folder_name="client_1_backup",
|
||||
dest_file_name="test_file.txt",
|
||||
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address,
|
||||
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
|
||||
)
|
||||
|
||||
assert ftp_server_service.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt")
|
||||
assert ftp_server.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt")
|
||||
|
||||
|
||||
def test_ftp_client_retrieve_file_from_server(uc2_network):
|
||||
def test_ftp_client_retrieve_file_from_server(ftp_client_and_ftp_server):
|
||||
"""
|
||||
Test checks to see if the client is able to retrieve files from the backup server.
|
||||
"""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
|
||||
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
|
||||
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
|
||||
|
||||
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
|
||||
|
||||
# create file on ftp server
|
||||
ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
|
||||
ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
|
||||
|
||||
assert ftp_client.request_file(
|
||||
src_folder_name="file_share",
|
||||
src_file_name="test_file.txt",
|
||||
dest_folder_name="downloads",
|
||||
dest_file_name="test_file.txt",
|
||||
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address,
|
||||
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
|
||||
)
|
||||
|
||||
# client should have retrieved the file
|
||||
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt")
|
||||
|
||||
|
||||
def test_ftp_client_tries_to_connect_to_offline_server(uc2_network):
|
||||
def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server):
|
||||
"""Test checks to make sure that the client can't do anything when the server is offline."""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
|
||||
|
||||
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
|
||||
ftp_server_service: FTPServer = backup_server.software_manager.software["FTPServer"]
|
||||
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
|
||||
|
||||
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server_service.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
|
||||
|
||||
# create file on ftp server
|
||||
ftp_server_service.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
|
||||
ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
|
||||
|
||||
backup_server.power_off()
|
||||
server.power_off()
|
||||
|
||||
for i in range(backup_server.shut_down_duration + 1):
|
||||
uc2_network.apply_timestep(timestep=i)
|
||||
for i in range(server.shut_down_duration + 1):
|
||||
server.apply_timestep(timestep=i)
|
||||
|
||||
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
|
||||
assert ftp_server_service.operating_state == ServiceOperatingState.STOPPED
|
||||
assert ftp_server.operating_state == ServiceOperatingState.STOPPED
|
||||
|
||||
assert (
|
||||
ftp_client.request_file(
|
||||
@@ -90,7 +98,7 @@ def test_ftp_client_tries_to_connect_to_offline_server(uc2_network):
|
||||
src_file_name="test_file.txt",
|
||||
dest_folder_name="downloads",
|
||||
dest_file_name="test_file.txt",
|
||||
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address,
|
||||
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ def populated_node(
|
||||
)
|
||||
server.software_manager.install(service_class)
|
||||
|
||||
service = server.software_manager.software["TestService"]
|
||||
service = server.software_manager.software.get("TestService")
|
||||
service.start()
|
||||
|
||||
return server, service
|
||||
@@ -34,7 +34,7 @@ def test_service_on_offline_node(service_class):
|
||||
)
|
||||
computer.software_manager.install(service_class)
|
||||
|
||||
service: Service = computer.software_manager.software["TestService"]
|
||||
service: Service = computer.software_manager.software.get("TestService")
|
||||
|
||||
computer.power_off()
|
||||
|
||||
|
||||
@@ -1,104 +1,118 @@
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.http import HttpStatusCode
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
def test_web_page_home_page(uc2_network):
|
||||
"""Test to see if the browser is able to open the main page of the web server."""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
|
||||
web_client.run()
|
||||
web_client.target_url = "http://arcd.com/"
|
||||
assert web_client.operating_state == ApplicationOperatingState.RUNNING
|
||||
@pytest.fixture(scope="function")
|
||||
def web_client_and_web_server(client_server) -> Tuple[WebBrowser, Computer, WebServer, Server]:
|
||||
computer, server = client_server
|
||||
|
||||
assert web_client.get_webpage() is True
|
||||
# Install Web Browser on computer
|
||||
computer.software_manager.install(WebBrowser)
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser.run()
|
||||
|
||||
# latest reponse should have status code 200
|
||||
assert web_client.latest_response is not None
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.OK
|
||||
# Install DNS Client service on computer
|
||||
computer.software_manager.install(DNSClient)
|
||||
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
|
||||
# set dns server
|
||||
dns_client.dns_server = server.nics[next(iter(server.nics))].ip_address
|
||||
|
||||
# Install Web Server service on server
|
||||
server.software_manager.install(WebServer)
|
||||
web_server_service: WebServer = server.software_manager.software.get("WebServer")
|
||||
web_server_service.start()
|
||||
|
||||
# Install DNS Server service on server
|
||||
server.software_manager.install(DNSServer)
|
||||
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
|
||||
# register arcd.com to DNS
|
||||
dns_server.dns_register(domain_name="arcd.com", domain_ip_address=server.nics[next(iter(server.nics))].ip_address)
|
||||
|
||||
return web_browser, computer, web_server_service, server
|
||||
|
||||
|
||||
def test_web_page_get_users_page_request_with_domain_name(uc2_network):
|
||||
def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_server):
|
||||
"""Test to see if the client can handle requests with domain names"""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
|
||||
web_client.run()
|
||||
assert web_client.operating_state == ApplicationOperatingState.RUNNING
|
||||
web_client.target_url = "http://arcd.com/users/"
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
assert web_client.get_webpage() is True
|
||||
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
assert web_client.latest_response is not None
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.OK
|
||||
assert web_browser_app.latest_response is not None
|
||||
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
|
||||
|
||||
|
||||
def test_web_page_get_users_page_request_with_ip_address(uc2_network):
|
||||
def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_server):
|
||||
"""Test to see if the client can handle requests that use ip_address."""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
|
||||
web_client.run()
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
|
||||
web_browser_app.target_url = f"http://{web_server_ip}/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address
|
||||
web_client.target_url = f"http://{web_server_ip}/users/"
|
||||
assert web_client.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_client.get_webpage() is True
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
assert web_client.latest_response is not None
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.OK
|
||||
assert web_browser_app.latest_response is not None
|
||||
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
|
||||
|
||||
|
||||
def test_web_page_request_from_shut_down_server(uc2_network):
|
||||
def test_web_page_request_from_shut_down_server(web_client_and_web_server):
|
||||
"""Test to see that the web server does not respond when the server is off."""
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
|
||||
web_client.run()
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_client.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_client.get_webpage("http://arcd.com/users/") is True
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.OK
|
||||
assert web_browser_app.latest_response is not None
|
||||
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
|
||||
|
||||
web_server.power_off()
|
||||
server.power_off()
|
||||
|
||||
for i in range(web_server.shut_down_duration + 1):
|
||||
uc2_network.apply_timestep(timestep=i)
|
||||
server.power_off()
|
||||
|
||||
for i in range(server.shut_down_duration + 1):
|
||||
server.apply_timestep(timestep=i)
|
||||
|
||||
# node should be off
|
||||
assert web_server.operating_state is NodeOperatingState.OFF
|
||||
assert server.operating_state is NodeOperatingState.OFF
|
||||
|
||||
assert web_client.get_webpage("http://arcd.com/users/") is False
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.NOT_FOUND
|
||||
assert web_browser_app.get_webpage() is False
|
||||
assert web_browser_app.latest_response.status_code == HttpStatusCode.NOT_FOUND
|
||||
|
||||
|
||||
def test_web_page_request_from_closed_web_browser(uc2_network):
|
||||
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
|
||||
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
|
||||
web_client.run()
|
||||
def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
|
||||
web_browser_app, computer, web_server_service, server = web_client_and_web_server
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
|
||||
assert web_client.operating_state == ApplicationOperatingState.RUNNING
|
||||
|
||||
assert web_client.get_webpage("http://arcd.com/users/") is True
|
||||
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
|
||||
web_browser_app.target_url = f"http://arcd.com/"
|
||||
assert web_browser_app.get_webpage() is True
|
||||
|
||||
# latest response should have status code 200
|
||||
assert web_client.latest_response.status_code == HttpStatusCode.OK
|
||||
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
|
||||
|
||||
web_client.close()
|
||||
web_browser_app.close()
|
||||
|
||||
# node should be off
|
||||
assert web_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
assert web_browser_app.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
assert web_client.get_webpage("http://arcd.com/users/") is False
|
||||
assert web_browser_app.get_webpage() is False
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Link
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def web_client_web_server_database(example_network) -> Tuple[Computer, Server, Server]:
|
||||
# add rules to network router
|
||||
router_1: Router = example_network.get_node_by_hostname("router_1")
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
|
||||
)
|
||||
|
||||
# Allow DNS requests
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
|
||||
|
||||
# Allow FTP requests
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2)
|
||||
|
||||
# Open port 80 for web server
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
|
||||
|
||||
# Create Computer
|
||||
computer: Computer = example_network.get_node_by_hostname("client_1")
|
||||
|
||||
# Create Web Server
|
||||
web_server: Server = example_network.get_node_by_hostname("server_1")
|
||||
|
||||
# Create Database Server
|
||||
db_server = example_network.get_node_by_hostname("server_2")
|
||||
|
||||
# Get the NICs
|
||||
computer_nic = computer.nics[next(iter(computer.nics))]
|
||||
server_nic = web_server.nics[next(iter(web_server.nics))]
|
||||
db_server_nic = db_server.nics[next(iter(db_server.nics))]
|
||||
|
||||
# Connect Computer and Server
|
||||
link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
|
||||
# Should be linked
|
||||
assert link_computer_server.is_up
|
||||
|
||||
# Connect database server and web server
|
||||
link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic)
|
||||
# Should be linked
|
||||
assert link_computer_server.is_up
|
||||
assert link_server_db.is_up
|
||||
|
||||
# Install DatabaseService on db server
|
||||
db_server.software_manager.install(DatabaseService)
|
||||
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
|
||||
db_service.start()
|
||||
|
||||
# Install Web Browser on computer
|
||||
computer.software_manager.install(WebBrowser)
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser.target_url = "http://arcd.com/users/"
|
||||
web_browser.run()
|
||||
|
||||
# Install DNS Client service on computer
|
||||
computer.software_manager.install(DNSClient)
|
||||
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
|
||||
# set dns server
|
||||
dns_client.dns_server = web_server.nics[next(iter(web_server.nics))].ip_address
|
||||
|
||||
# Install Web Server service on web server
|
||||
web_server.software_manager.install(WebServer)
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
web_server_service.start()
|
||||
|
||||
# Install DNS Server service on web server
|
||||
web_server.software_manager.install(DNSServer)
|
||||
dns_server: DNSServer = web_server.software_manager.software.get("DNSServer")
|
||||
# register arcd.com to DNS
|
||||
dns_server.dns_register(
|
||||
domain_name="arcd.com", domain_ip_address=web_server.nics[next(iter(web_server.nics))].ip_address
|
||||
)
|
||||
|
||||
# Install DatabaseClient service on web server
|
||||
web_server.software_manager.install(DatabaseClient)
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
db_client.server_ip_address = IPv4Address(db_server_nic.ip_address) # set IP address of Database Server
|
||||
db_client.run()
|
||||
assert dns_client.check_domain_exists("arcd.com")
|
||||
assert db_client.connect()
|
||||
|
||||
return computer, web_server, db_server
|
||||
|
||||
|
||||
def test_web_client_requests_users(web_client_web_server_database):
|
||||
computer, web_server, db_server = web_client_web_server_database
|
||||
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
|
||||
assert web_browser.get_webpage()
|
||||
@@ -1,18 +1,140 @@
|
||||
"""Test the account module of the simulator."""
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.domain.account import Account, AccountType
|
||||
|
||||
|
||||
def test_account_serialise():
|
||||
@pytest.fixture(scope="function")
|
||||
def account() -> Account:
|
||||
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
|
||||
acct.set_original_state()
|
||||
return acct
|
||||
|
||||
|
||||
def test_original_state(account):
|
||||
"""Test the original state - see if it resets properly"""
|
||||
account.log_on()
|
||||
account.log_off()
|
||||
account.disable()
|
||||
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is False
|
||||
|
||||
account.reset_component_for_episode(episode=1)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 0
|
||||
assert state["num_logoffs"] is 0
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is True
|
||||
|
||||
account.log_on()
|
||||
account.log_off()
|
||||
account.disable()
|
||||
account.set_original_state()
|
||||
|
||||
account.log_on()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 2
|
||||
|
||||
account.reset_component_for_episode(episode=2)
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is False
|
||||
|
||||
|
||||
def test_enable(account):
|
||||
"""Should enable the account."""
|
||||
account.enabled = False
|
||||
account.enable()
|
||||
assert account.enabled is True
|
||||
|
||||
|
||||
def test_disable(account):
|
||||
"""Should disable the account."""
|
||||
account.enabled = True
|
||||
account.disable()
|
||||
assert account.enabled is False
|
||||
|
||||
|
||||
def test_log_on_increments(account):
|
||||
"""Should increase the log on value by 1."""
|
||||
account.num_logons = 0
|
||||
account.log_on()
|
||||
assert account.num_logons is 1
|
||||
|
||||
|
||||
def test_log_off_increments(account):
|
||||
"""Should increase the log on value by 1."""
|
||||
account.num_logoffs = 0
|
||||
account.log_off()
|
||||
assert account.num_logoffs is 1
|
||||
|
||||
|
||||
def test_account_serialise(account):
|
||||
"""Test that an account can be serialised. If pydantic throws error then this test fails."""
|
||||
acct = Account(username="Jake", password="JakePass1!", account_type=AccountType.USER)
|
||||
serialised = acct.model_dump_json()
|
||||
serialised = account.model_dump_json()
|
||||
print(serialised)
|
||||
|
||||
|
||||
def test_account_deserialise():
|
||||
def test_account_deserialise(account):
|
||||
"""Test that an account can be deserialised. The test fails if pydantic throws an error."""
|
||||
acct_json = (
|
||||
'{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,'
|
||||
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}'
|
||||
'"username":"Jake","password":"totally_hashed_password","account_type":2,"status":2,"request_manager":null}'
|
||||
)
|
||||
acct = Account.model_validate_json(acct_json)
|
||||
assert Account.model_validate_json(acct_json)
|
||||
|
||||
|
||||
def test_describe_state(account):
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 0
|
||||
assert state["num_logoffs"] is 0
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is True
|
||||
|
||||
account.log_on()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 0
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is True
|
||||
|
||||
account.log_off()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is True
|
||||
|
||||
account.disable()
|
||||
state = account.describe_state()
|
||||
assert state["num_logons"] is 1
|
||||
assert state["num_logoffs"] is 1
|
||||
assert state["num_group_changes"] is 0
|
||||
assert state["username"] is "Jake"
|
||||
assert state["password"] is "totally_hashed_password"
|
||||
assert state["account_type"] is AccountType.USER.value
|
||||
assert state["enabled"] is False
|
||||
|
||||
@@ -185,6 +185,38 @@ def test_get_file(file_system):
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
def test_reset_file_system(file_system):
|
||||
# file and folder that existed originally
|
||||
file_system.create_file(file_name="test_file.zip")
|
||||
file_system.create_folder(folder_name="test_folder")
|
||||
file_system.set_original_state()
|
||||
|
||||
# create a new file
|
||||
file_system.create_file(file_name="new_file.txt")
|
||||
|
||||
# create a new folder
|
||||
file_system.create_folder(folder_name="new_folder")
|
||||
|
||||
# delete the file that existed originally
|
||||
file_system.delete_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
|
||||
|
||||
# delete the folder that existed originally
|
||||
file_system.delete_folder(folder_name="test_folder")
|
||||
assert file_system.get_folder(folder_name="test_folder") is None
|
||||
|
||||
# reset
|
||||
file_system.reset_component_for_episode(episode=1)
|
||||
|
||||
# deleted original file and folder should be back
|
||||
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
|
||||
assert file_system.get_folder(folder_name="test_folder")
|
||||
|
||||
# new file and folder should be removed
|
||||
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
|
||||
assert file_system.get_folder(folder_name="new_folder") is None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
|
||||
def test_serialisation(file_system):
|
||||
"""Test to check that the object serialisation works correctly."""
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def switch() -> Switch:
|
||||
switch: Switch = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
|
||||
switch.show()
|
||||
return switch
|
||||
|
||||
|
||||
def test_describe_state(switch):
|
||||
state = switch.describe_state()
|
||||
assert len(state.get("ports")) is 8
|
||||
assert state.get("num_ports") is 8
|
||||
@@ -3,6 +3,66 @@ import json
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def network(example_network) -> Network:
|
||||
assert len(example_network.routers) is 1
|
||||
assert len(example_network.switches) is 2
|
||||
assert len(example_network.computers) is 2
|
||||
assert len(example_network.servers) is 2
|
||||
|
||||
example_network.set_original_state()
|
||||
example_network.show()
|
||||
|
||||
return example_network
|
||||
|
||||
|
||||
def test_describe_state(network):
|
||||
"""Test that describe state works."""
|
||||
state = network.describe_state()
|
||||
|
||||
assert len(state["nodes"]) is 7
|
||||
assert len(state["links"]) is 6
|
||||
|
||||
|
||||
def test_reset_network(network):
|
||||
"""
|
||||
Test that the network is properly reset.
|
||||
|
||||
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
|
||||
etc are also removed/reinstalled
|
||||
|
||||
"""
|
||||
state_before = network.describe_state()
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
server_1: Computer = network.get_node_by_hostname("server_1")
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
client_1.power_off()
|
||||
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
server_1.power_off()
|
||||
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
|
||||
|
||||
assert network.describe_state() is not state_before
|
||||
|
||||
network.reset_component_for_episode(episode=1)
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
assert server_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
|
||||
state_before, sort_keys=True, indent=2
|
||||
)
|
||||
|
||||
|
||||
def test_creating_container():
|
||||
@@ -10,11 +70,46 @@ def test_creating_container():
|
||||
net = Network()
|
||||
assert net.nodes == {}
|
||||
assert net.links == {}
|
||||
net.show()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
|
||||
def test_describe_state():
|
||||
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
|
||||
net = Network()
|
||||
state = net.describe_state()
|
||||
json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable
|
||||
def test_apply_timestep_to_nodes(network):
|
||||
"""Calling apply_timestep on the network should apply to the nodes within it."""
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
assert client_1.operating_state is NodeOperatingState.ON
|
||||
|
||||
client_1.power_off()
|
||||
|
||||
for i in range(client_1.shut_down_duration + 1):
|
||||
network.apply_timestep(timestep=i)
|
||||
|
||||
assert client_1.operating_state is NodeOperatingState.OFF
|
||||
|
||||
|
||||
def test_removing_node_that_does_not_exist(network):
|
||||
"""Node that does not exist on network should not affect existing nodes."""
|
||||
assert len(network.nodes) is 7
|
||||
|
||||
network.remove_node(Node(hostname="new_node"))
|
||||
assert len(network.nodes) is 7
|
||||
|
||||
|
||||
def test_remove_node(network):
|
||||
"""Remove node should remove the correct node."""
|
||||
assert len(network.nodes) is 7
|
||||
|
||||
client_1: Computer = network.get_node_by_hostname("client_1")
|
||||
network.remove_node(client_1)
|
||||
|
||||
assert network.get_node_by_hostname("client_1") is None
|
||||
assert len(network.nodes) is 6
|
||||
|
||||
|
||||
def test_remove_link(network):
|
||||
"""Remove link should remove the correct link."""
|
||||
assert len(network.links) is 6
|
||||
link: Link = network.links.get(next(iter(network.links)))
|
||||
|
||||
network.remove_link(link)
|
||||
assert len(network.links) is 5
|
||||
assert network.links.get(link.uuid) is None
|
||||
|
||||
11
tests/unit_tests/_primaite/_simulator/_network/test_utils.py
Normal file
11
tests/unit_tests/_primaite/_simulator/_network/test_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from primaite.simulator.network.utils import convert_bytes_to_megabits, convert_megabits_to_bytes
|
||||
|
||||
|
||||
def test_convert_bytes_to_megabits():
|
||||
assert round(convert_bytes_to_megabits(B=131072), 5) == float(1)
|
||||
assert round(convert_bytes_to_megabits(B=69420), 5) == float(0.52963)
|
||||
|
||||
|
||||
def test_convert_megabits_to_bytes():
|
||||
assert round(convert_megabits_to_bytes(Mbits=1), 5) == float(131072)
|
||||
assert round(convert_megabits_to_bytes(Mbits=float(0.52963)), 5) == float(69419.66336)
|
||||
@@ -0,0 +1,122 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def database_client_on_computer() -> Tuple[DatabaseClient, Computer]:
|
||||
computer = Computer(
|
||||
hostname="db_node", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
|
||||
)
|
||||
computer.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient")
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
|
||||
database_client.run()
|
||||
return database_client, computer
|
||||
|
||||
|
||||
def test_creation(database_client_on_computer):
|
||||
database_client, computer = database_client_on_computer
|
||||
database_client.describe_state()
|
||||
|
||||
|
||||
def test_connect_when_client_is_closed(database_client_on_computer):
|
||||
"""Database client should not connect when it is not running."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.close()
|
||||
assert database_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
assert database_client.connect() is False
|
||||
|
||||
|
||||
def test_connect_to_database_fails_on_reattempt(database_client_on_computer):
|
||||
"""Database client should return False when the attempt to connect fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = False
|
||||
assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False
|
||||
|
||||
|
||||
def test_disconnect_when_client_is_closed(database_client_on_computer):
|
||||
"""Database client disconnect should not do anything when it is not running."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = True
|
||||
assert database_client.server_ip_address is not None
|
||||
|
||||
database_client.close()
|
||||
assert database_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
database_client.disconnect()
|
||||
|
||||
assert database_client.connected is True
|
||||
assert database_client.server_ip_address is not None
|
||||
|
||||
|
||||
def test_disconnect(database_client_on_computer):
|
||||
"""Database client should set connected to False and remove the database server ip address."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = True
|
||||
|
||||
assert database_client.operating_state is ApplicationOperatingState.RUNNING
|
||||
assert database_client.server_ip_address is not None
|
||||
|
||||
database_client.disconnect()
|
||||
|
||||
assert database_client.connected is False
|
||||
assert database_client.server_ip_address is None
|
||||
|
||||
|
||||
def test_query_when_client_is_closed(database_client_on_computer):
|
||||
"""Database client should return False when it is not running."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.close()
|
||||
assert database_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
assert database_client.query(sql="test") is False
|
||||
|
||||
|
||||
def test_query_failed_reattempt(database_client_on_computer):
|
||||
"""Database client query should return False if the reattempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
def return_false():
|
||||
return False
|
||||
|
||||
database_client.connect = return_false
|
||||
|
||||
database_client.connected = False
|
||||
assert database_client.query(sql="test", is_reattempt=True) is False
|
||||
|
||||
|
||||
def test_query_fail_to_connect(database_client_on_computer):
|
||||
"""Database client query should return False if the connect attempt fails."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
def return_false():
|
||||
return False
|
||||
|
||||
database_client.connect = return_false
|
||||
database_client.connected = False
|
||||
|
||||
assert database_client.query(sql="test") is False
|
||||
|
||||
|
||||
def test_client_receives_response_when_closed(database_client_on_computer):
|
||||
"""Database client receive should return False when it is closed."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.close()
|
||||
assert database_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
database_client.receive(payload={}, session_id="")
|
||||
@@ -1,39 +1,66 @@
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def web_client() -> Computer:
|
||||
node = Computer(
|
||||
hostname="web_client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
def web_browser() -> WebBrowser:
|
||||
computer = Computer(
|
||||
hostname="web_client",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
return node
|
||||
# Web Browser should be pre-installed in computer
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser.run()
|
||||
assert web_browser.operating_state is ApplicationOperatingState.RUNNING
|
||||
return web_browser
|
||||
|
||||
|
||||
def test_create_web_client(web_client):
|
||||
assert web_client is not None
|
||||
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
|
||||
def test_create_web_client():
|
||||
computer = Computer(
|
||||
hostname="web_client",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
# Web Browser should be pre-installed in computer
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
assert web_browser.name is "WebBrowser"
|
||||
assert web_browser.port is Port.HTTP
|
||||
assert web_browser.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_receive_invalid_payload(web_client):
|
||||
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
|
||||
|
||||
def test_receive_invalid_payload(web_browser):
|
||||
assert web_browser.receive(payload={}) is False
|
||||
|
||||
|
||||
def test_receive_payload(web_client):
|
||||
def test_receive_payload(web_browser):
|
||||
payload = HttpResponsePacket(status_code=HttpStatusCode.OK)
|
||||
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
|
||||
assert web_browser.latest_response is None
|
||||
|
||||
web_browser.receive(payload=payload)
|
||||
|
||||
assert web_browser.latest_response is not None
|
||||
|
||||
|
||||
def test_invalid_target_url(web_browser):
|
||||
# none value target url
|
||||
web_browser.target_url = None
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
|
||||
def test_non_existent_target_url(web_browser):
|
||||
web_browser.target_url = "http://192.168.255.255"
|
||||
assert web_browser.get_webpage() is False
|
||||
|
||||
@@ -19,11 +19,11 @@ def dm_client() -> Node:
|
||||
|
||||
@pytest.fixture
|
||||
def dm_bot(dm_client) -> DataManipulationBot:
|
||||
return dm_client.software_manager.software["DataManipulationBot"]
|
||||
return dm_client.software_manager.software.get("DataManipulationBot")
|
||||
|
||||
|
||||
def test_create_dm_bot(dm_client):
|
||||
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"]
|
||||
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
|
||||
|
||||
assert data_manipulation_bot.name == "DataManipulationBot"
|
||||
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
|
||||
|
||||
@@ -8,7 +8,7 @@ from primaite.simulator.system.services.database.database_service import Databas
|
||||
def database_server() -> Node:
|
||||
node = Node(hostname="db_node")
|
||||
node.software_manager.install(DatabaseService)
|
||||
node.software_manager.software["DatabaseService"].start()
|
||||
node.software_manager.software.get("DatabaseService").start()
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@@ -5,28 +5,13 @@ import pytest
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_server() -> Node:
|
||||
node = Server(
|
||||
hostname="dns_server",
|
||||
ip_address="192.168.1.10",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
node.software_manager.install(software_class=DNSServer)
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_client() -> Node:
|
||||
node = Computer(
|
||||
@@ -39,24 +24,16 @@ def dns_client() -> Node:
|
||||
return node
|
||||
|
||||
|
||||
def test_create_dns_server(dns_server):
|
||||
assert dns_server is not None
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
assert dns_server_service.name is "DNSServer"
|
||||
assert dns_server_service.port is Port.DNS
|
||||
assert dns_server_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_create_dns_client(dns_client):
|
||||
assert dns_client is not None
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
assert dns_client_service.name is "DNSClient"
|
||||
assert dns_client_service.port is Port.DNS
|
||||
assert dns_client_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
assert dns_client.operating_state is NodeOperatingState.OFF
|
||||
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
|
||||
|
||||
@@ -69,7 +46,7 @@ def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
|
||||
|
||||
def test_dns_client_check_domain_exists_when_not_running(dns_client):
|
||||
dns_client.operating_state = NodeOperatingState.ON
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
dns_client_service.start()
|
||||
|
||||
assert dns_client.operating_state is NodeOperatingState.ON
|
||||
@@ -93,22 +70,10 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client):
|
||||
assert dns_client_service.check_domain_exists("test.com") is False
|
||||
|
||||
|
||||
def test_dns_server_domain_name_registration(dns_server):
|
||||
"""Test to check if the domain name registration works."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
# return none for an unknown domain
|
||||
assert dns_server_service.dns_lookup("fake-domain.com") is None
|
||||
assert dns_server_service.dns_lookup("real-domain.com") is not None
|
||||
|
||||
|
||||
def test_dns_client_check_domain_in_cache(dns_client):
|
||||
"""Test to make sure that the check_domain_in_cache returns the correct values."""
|
||||
dns_client.operating_state = NodeOperatingState.ON
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
dns_client_service.start()
|
||||
|
||||
# add a domain to the dns client cache
|
||||
@@ -118,29 +83,9 @@ def test_dns_client_check_domain_in_cache(dns_client):
|
||||
assert dns_client_service.check_domain_exists("real-domain.com") is True
|
||||
|
||||
|
||||
def test_dns_server_receive(dns_server):
|
||||
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
|
||||
is False
|
||||
)
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
|
||||
is True
|
||||
)
|
||||
|
||||
dns_server_service.show()
|
||||
|
||||
|
||||
def test_dns_client_receive(dns_client):
|
||||
"""Test to make sure the DNS Client knows how to deal with request responses."""
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
|
||||
dns_client_service.receive(
|
||||
payload=DNSPacket(
|
||||
@@ -151,3 +96,9 @@ def test_dns_client_receive(dns_client):
|
||||
|
||||
# domain name should be saved to cache
|
||||
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")
|
||||
|
||||
|
||||
def test_dns_client_receive_non_dns_payload(dns_client):
|
||||
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
|
||||
|
||||
assert dns_client_service.receive(payload=None) is False
|
||||
@@ -0,0 +1,64 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def dns_server() -> Node:
|
||||
node = Server(
|
||||
hostname="dns_server",
|
||||
ip_address="192.168.1.10",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
node.software_manager.install(software_class=DNSServer)
|
||||
return node
|
||||
|
||||
|
||||
def test_create_dns_server(dns_server):
|
||||
assert dns_server is not None
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
|
||||
assert dns_server_service.name is "DNSServer"
|
||||
assert dns_server_service.port is Port.DNS
|
||||
assert dns_server_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_dns_server_domain_name_registration(dns_server):
|
||||
"""Test to check if the domain name registration works."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
# return none for an unknown domain
|
||||
assert dns_server_service.dns_lookup("fake-domain.com") is None
|
||||
assert dns_server_service.dns_lookup("real-domain.com") is not None
|
||||
|
||||
|
||||
def test_dns_server_receive(dns_server):
|
||||
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
|
||||
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
|
||||
|
||||
# register the web server in the domain controller
|
||||
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="fake-domain.com")))
|
||||
is False
|
||||
)
|
||||
|
||||
assert (
|
||||
dns_server_service.receive(payload=DNSPacket(dns_request=DNSRequest(domain_name_request="real-domain.com")))
|
||||
is True
|
||||
)
|
||||
|
||||
dns_server_service.show()
|
||||
@@ -0,0 +1,122 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ftp_client() -> Node:
|
||||
node = Computer(
|
||||
hostname="ftp_client",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
def test_create_ftp_client(ftp_client):
|
||||
assert ftp_client is not None
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
assert ftp_client_service.name is "FTPClient"
|
||||
assert ftp_client_service.port is Port.FTP
|
||||
assert ftp_client_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_ftp_client_store_file(ftp_client):
|
||||
"""Test to make sure the FTP Client knows how to deal with request responses."""
|
||||
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
|
||||
|
||||
response: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.STOR,
|
||||
ftp_command_args={
|
||||
"dest_folder_name": "downloads",
|
||||
"dest_file_name": "file.txt",
|
||||
"file_size": 24,
|
||||
},
|
||||
packet_payload_size=24,
|
||||
status_code=FTPStatusCode.OK,
|
||||
)
|
||||
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
ftp_client_service.receive(response)
|
||||
|
||||
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
|
||||
|
||||
|
||||
def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
|
||||
"""Method _process_ftp_command should return false if service is not running."""
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.PORT,
|
||||
ftp_command_args=Port.FTP,
|
||||
status_code=FTPStatusCode.OK,
|
||||
)
|
||||
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
ftp_client_service.stop()
|
||||
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
|
||||
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
|
||||
|
||||
|
||||
def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
|
||||
"""Method send_file should return false if no file to send."""
|
||||
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None
|
||||
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING
|
||||
assert (
|
||||
ftp_client_service.send_file(
|
||||
dest_ip_address=IPv4Address("192.168.1.1"),
|
||||
src_folder_name="root",
|
||||
src_file_name="test.txt",
|
||||
dest_folder_name="root",
|
||||
dest_file_name="text.txt",
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
def test_offline_ftp_client_receives_request(ftp_client):
|
||||
"""Receive should return false if the node the ftp client is installed on is offline."""
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
ftp_client.power_off()
|
||||
|
||||
for i in range(ftp_client.shut_down_duration + 1):
|
||||
ftp_client.apply_timestep(timestep=i)
|
||||
|
||||
assert ftp_client.operating_state is NodeOperatingState.OFF
|
||||
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
|
||||
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.PORT,
|
||||
ftp_command_args=Port.FTP,
|
||||
status_code=FTPStatusCode.OK,
|
||||
)
|
||||
|
||||
assert ftp_client_service.receive(payload=payload) is False
|
||||
|
||||
|
||||
def test_receive_should_fail_if_payload_is_not_ftp(ftp_client):
|
||||
"""Receive should return false if the node the ftp client is installed on is not an FTPPacket."""
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
assert ftp_client_service.receive(payload=None) is False
|
||||
|
||||
|
||||
def test_receive_should_ignore_payload_with_none_status_code(ftp_client):
|
||||
"""Receive should ignore payload with no set status code to prevent infinite send/receive loops."""
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.PORT,
|
||||
ftp_command_args=Port.FTP,
|
||||
status_code=None,
|
||||
)
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
|
||||
assert ftp_client_service.receive(payload=payload) is False
|
||||
@@ -1,16 +1,13 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -26,34 +23,14 @@ def ftp_server() -> Node:
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ftp_client() -> Node:
|
||||
node = Computer(
|
||||
hostname="ftp_client",
|
||||
ip_address="192.168.1.11",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
def test_create_ftp_server(ftp_server):
|
||||
assert ftp_server is not None
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
|
||||
assert ftp_server_service.name is "FTPServer"
|
||||
assert ftp_server_service.port is Port.FTP
|
||||
assert ftp_server_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_create_ftp_client(ftp_client):
|
||||
assert ftp_client is not None
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
|
||||
assert ftp_client_service.name is "FTPClient"
|
||||
assert ftp_client_service.port is Port.FTP
|
||||
assert ftp_client_service.protocol is IPProtocol.TCP
|
||||
|
||||
|
||||
def test_ftp_server_store_file(ftp_server):
|
||||
"""Test to make sure the FTP Server knows how to deal with request responses."""
|
||||
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
|
||||
@@ -68,16 +45,34 @@ def test_ftp_server_store_file(ftp_server):
|
||||
packet_payload_size=24,
|
||||
)
|
||||
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
|
||||
ftp_server_service.receive(response)
|
||||
|
||||
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt")
|
||||
|
||||
|
||||
def test_ftp_client_store_file(ftp_client):
|
||||
"""Test to make sure the FTP Client knows how to deal with request responses."""
|
||||
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
|
||||
def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server):
|
||||
"""Should fail if the port command receives an invalid port."""
|
||||
payload: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.PORT,
|
||||
ftp_command_args=None,
|
||||
packet_payload_size=24,
|
||||
)
|
||||
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
|
||||
assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
|
||||
|
||||
|
||||
def test_ftp_server_receives_non_ftp_packet(ftp_server):
|
||||
"""Receive should return false if the service receives a non ftp packet."""
|
||||
response: FTPPacket = None
|
||||
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
|
||||
assert ftp_server_service.receive(response) is False
|
||||
|
||||
|
||||
def test_offline_ftp_server_receives_request(ftp_server):
|
||||
"""Receive should return false if the service is stopped."""
|
||||
response: FTPPacket = FTPPacket(
|
||||
ftp_command=FTPCommand.STOR,
|
||||
ftp_command_args={
|
||||
@@ -86,10 +81,9 @@ def test_ftp_client_store_file(ftp_client):
|
||||
"file_size": 24,
|
||||
},
|
||||
packet_payload_size=24,
|
||||
status_code=FTPStatusCode.OK,
|
||||
)
|
||||
|
||||
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
|
||||
ftp_client_service.receive(response)
|
||||
|
||||
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
|
||||
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
|
||||
ftp_server_service.stop()
|
||||
assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED
|
||||
assert ftp_server_service.receive(response) is False
|
||||
@@ -18,13 +18,13 @@ def web_server() -> Server:
|
||||
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
|
||||
)
|
||||
node.software_manager.install(software_class=WebServer)
|
||||
node.software_manager.software["WebServer"].start()
|
||||
node.software_manager.software.get("WebServer").start()
|
||||
return node
|
||||
|
||||
|
||||
def test_create_web_server(web_server):
|
||||
assert web_server is not None
|
||||
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
assert web_server_service.name is "WebServer"
|
||||
assert web_server_service.port is Port.HTTP
|
||||
assert web_server_service.protocol is IPProtocol.TCP
|
||||
@@ -33,7 +33,7 @@ def test_create_web_server(web_server):
|
||||
def test_handling_get_request_not_found_path(web_server):
|
||||
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path")
|
||||
|
||||
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
|
||||
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
|
||||
assert response.status_code == HttpStatusCode.NOT_FOUND
|
||||
@@ -42,7 +42,7 @@ def test_handling_get_request_not_found_path(web_server):
|
||||
def test_handling_get_request_home_page(web_server):
|
||||
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
|
||||
|
||||
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
|
||||
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
|
||||
assert response.status_code == HttpStatusCode.OK
|
||||
@@ -51,7 +51,7 @@ def test_handling_get_request_home_page(web_server):
|
||||
def test_process_http_request_get(web_server):
|
||||
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
|
||||
|
||||
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
|
||||
assert web_server_service._process_http_request(payload=payload) is True
|
||||
|
||||
@@ -59,6 +59,6 @@ def test_process_http_request_get(web_server):
|
||||
def test_process_http_request_method_not_allowed(web_server):
|
||||
payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/")
|
||||
|
||||
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
|
||||
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
|
||||
|
||||
assert web_server_service._process_http_request(payload=payload) is False
|
||||
|
||||
Reference in New Issue
Block a user