#2084: change all instances of retrieving software from software['software_name'] to software.get() + adding some tests for describe state

This commit is contained in:
Czar Echavez
2023-11-30 13:49:37 +00:00
parent 7c1ffb5ba1
commit 3cf21e4015
30 changed files with 394 additions and 97 deletions

View File

@@ -1,18 +1,140 @@
"""Test the account module of the simulator."""
import pytest
from primaite.simulator.domain.account import Account, AccountType
def test_account_serialise():
@pytest.fixture(scope="function")
def account() -> Account:
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
acct.set_original_state()
return acct
def test_original_state(account):
"""Test the original state - see if it resets properly"""
account.log_on()
account.log_off()
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
account.reset_component_for_episode(episode=1)
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
account.log_off()
account.disable()
account.set_original_state()
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 2
account.reset_component_for_episode(episode=2)
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
def test_enable(account):
"""Should enable the account."""
account.enabled = False
account.enable()
assert account.enabled is True
def test_disable(account):
"""Should disable the account."""
account.enabled = True
account.disable()
assert account.enabled is False
def test_log_on_increments(account):
"""Should increase the log on value by 1."""
account.num_logons = 0
account.log_on()
assert account.num_logons is 1
def test_log_off_increments(account):
"""Should increase the log on value by 1."""
account.num_logoffs = 0
account.log_off()
assert account.num_logoffs is 1
def test_account_serialise(account):
"""Test that an account can be serialised. If pydantic throws error then this test fails."""
acct = Account(username="Jake", password="JakePass1!", account_type=AccountType.USER)
serialised = acct.model_dump_json()
serialised = account.model_dump_json()
print(serialised)
def test_account_deserialise():
def test_account_deserialise(account):
"""Test that an account can be deserialised. The test fails if pydantic throws an error."""
acct_json = (
'{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,'
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}'
'"username":"Jake","password":"totally_hashed_password","account_type":2,"status":2,"request_manager":null}'
)
acct = Account.model_validate_json(acct_json)
assert Account.model_validate_json(acct_json)
def test_describe_state(account):
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_off()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False

View File

@@ -3,6 +3,64 @@ import json
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
assert len(example_network.switches) is 2
assert len(example_network.computers) is 2
assert len(example_network.servers) is 2
example_network.set_original_state()
return example_network
def test_describe_state(example_network):
"""Test that describe state works."""
state = example_network.describe_state()
assert len(state["nodes"]) is 7
assert len(state["links"]) is 6
def test_reset_network(example_network):
"""
Test that the network is properly reset.
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
etc are also removed/reinstalled
"""
state_before = example_network.describe_state()
client_1: Computer = example_network.get_node_by_hostname("client_1")
server_1: Computer = example_network.get_node_by_hostname("server_1")
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
server_1.power_off()
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
assert example_network.describe_state() is not state_before
example_network.reset_component_for_episode(episode=1)
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(example_network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
def test_creating_container():
@@ -10,11 +68,3 @@ def test_creating_container():
net = Network()
assert net.nodes == {}
assert net.links == {}
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_describe_state():
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
net = Network()
state = net.describe_state()
json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
from typing import Tuple, Union
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
@pytest.fixture(scope="function")
def database_client_on_computer() -> Tuple[DatabaseClient, Computer]:
computer = Computer(
hostname="db_node", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
computer.software_manager.install(DatabaseClient)
database_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
database_client.run()
return database_client, computer
def test_creation(database_client_on_computer):
database_client, computer = database_client_on_computer
database_client.describe_state()
def test_connect_when_client_is_closed(database_client_on_computer):
"""Database client should not connect when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.connect() is False
def test_connect_to_database_fails_on_reattempt(database_client_on_computer):
"""Database client should return False when the attempt to connect fails."""
database_client, computer = database_client_on_computer
database_client.connected = False
assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False
def test_disconnect_when_client_is_closed(database_client_on_computer):
"""Database client disconnect should not do anything when it is not running."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.server_ip_address is not None
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.disconnect()
assert database_client.connected is True
assert database_client.server_ip_address is not None
def test_disconnect(database_client_on_computer):
"""Database client should set connected to False and remove the database server ip address."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.operating_state is ApplicationOperatingState.RUNNING
assert database_client.server_ip_address is not None
database_client.disconnect()
assert database_client.connected is False
assert database_client.server_ip_address is None
def test_query_when_client_is_closed(database_client_on_computer):
"""Database client should return False when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.query(sql="test") is False
def test_query_failed_reattempt(database_client_on_computer):
"""Database client query should return False if the reattempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test", is_reattempt=True) is False
def test_query_fail_to_connect(database_client_on_computer):
"""Database client query should return False if the connect attempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test") is False
def test_client_receives_response_when_closed(database_client_on_computer):
"""Database client receive should return False when it is closed."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.receive(payload={}, session_id="")

View File

@@ -21,7 +21,7 @@ def web_browser() -> WebBrowser:
operating_state=NodeOperatingState.ON,
)
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software["WebBrowser"]
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.run()
assert web_browser.operating_state is ApplicationOperatingState.RUNNING
return web_browser
@@ -36,7 +36,7 @@ def test_create_web_client():
operating_state=NodeOperatingState.ON,
)
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software["WebBrowser"]
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
assert web_browser.name is "WebBrowser"
assert web_browser.port is Port.HTTP
assert web_browser.protocol is IPProtocol.TCP

View File

@@ -19,11 +19,11 @@ def dm_client() -> Node:
@pytest.fixture
def dm_bot(dm_client) -> DataManipulationBot:
return dm_client.software_manager.software["DataManipulationBot"]
return dm_client.software_manager.software.get("DataManipulationBot")
def test_create_dm_bot(dm_client):
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software["DataManipulationBot"]
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER

View File

@@ -8,7 +8,7 @@ from primaite.simulator.system.services.database.database_service import Databas
def database_server() -> Node:
node = Node(hostname="db_node")
node.software_manager.install(DatabaseService)
node.software_manager.software["DatabaseService"].start()
node.software_manager.software.get("DatabaseService").start()
return node

View File

@@ -26,14 +26,14 @@ def dns_client() -> Node:
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client.operating_state is NodeOperatingState.OFF
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
@@ -46,7 +46,7 @@ def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
def test_dns_client_check_domain_exists_when_not_running(dns_client):
dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start()
assert dns_client.operating_state is NodeOperatingState.ON
@@ -73,7 +73,7 @@ def test_dns_client_check_domain_exists_when_not_running(dns_client):
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start()
# add a domain to the dns client cache
@@ -85,7 +85,7 @@ def test_dns_client_check_domain_in_cache(dns_client):
def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.receive(
payload=DNSPacket(
@@ -99,6 +99,6 @@ def test_dns_client_receive(dns_client):
def test_dns_client_receive_non_dns_payload(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.receive(payload=None) is False

View File

@@ -26,7 +26,7 @@ def dns_server() -> Node:
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
@@ -34,7 +34,7 @@ def test_create_dns_server(dns_server):
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
@@ -46,7 +46,7 @@ def test_dns_server_domain_name_registration(dns_server):
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))

View File

@@ -26,7 +26,7 @@ def ftp_client() -> Node:
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.name is "FTPClient"
assert ftp_client_service.port is Port.FTP
assert ftp_client_service.protocol is IPProtocol.TCP
@@ -47,7 +47,7 @@ def test_ftp_client_store_file(ftp_client):
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.receive(response)
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
@@ -61,7 +61,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.stop()
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
@@ -71,7 +71,7 @@ def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
"""Method send_file should return false if no file to send."""
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING
assert (
ftp_client_service.send_file(
@@ -87,7 +87,7 @@ def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
def test_offline_ftp_client_receives_request(ftp_client):
"""Receive should return false if the node the ftp client is installed on is offline."""
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client.power_off()
for i in range(ftp_client.shut_down_duration + 1):
@@ -107,7 +107,7 @@ def test_offline_ftp_client_receives_request(ftp_client):
def test_receive_should_fail_if_payload_is_not_ftp(ftp_client):
"""Receive should return false if the node the ftp client is installed on is not an FTPPacket."""
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=None) is False
@@ -118,5 +118,5 @@ def test_receive_should_ignore_payload_with_none_status_code(ftp_client):
ftp_command_args=Port.FTP,
status_code=None,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=payload) is False

View File

@@ -25,7 +25,7 @@ def ftp_server() -> Node:
def test_create_ftp_server(ftp_server):
assert ftp_server is not None
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.name is "FTPServer"
assert ftp_server_service.port is Port.FTP
assert ftp_server_service.protocol is IPProtocol.TCP
@@ -45,7 +45,7 @@ def test_ftp_server_store_file(ftp_server):
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_server_service.receive(response)
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt")
@@ -59,7 +59,7 @@ def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server):
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
@@ -67,7 +67,7 @@ def test_ftp_server_receives_non_ftp_packet(ftp_server):
"""Receive should return false if the service receives a non ftp packet."""
response: FTPPacket = None
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.receive(response) is False
@@ -83,7 +83,7 @@ def test_offline_ftp_server_receives_request(ftp_server):
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_server_service.stop()
assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_server_service.receive(response) is False

View File

@@ -18,13 +18,13 @@ def web_server() -> Server:
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
node.software_manager.install(software_class=WebServer)
node.software_manager.software["WebServer"].start()
node.software_manager.software.get("WebServer").start()
return node
def test_create_web_server(web_server):
assert web_server is not None
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service.name is "WebServer"
assert web_server_service.port is Port.HTTP
assert web_server_service.protocol is IPProtocol.TCP
@@ -33,7 +33,7 @@ def test_create_web_server(web_server):
def test_handling_get_request_not_found_path(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.NOT_FOUND
@@ -42,7 +42,7 @@ def test_handling_get_request_not_found_path(web_server):
def test_handling_get_request_home_page(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.OK
@@ -51,7 +51,7 @@ def test_handling_get_request_home_page(web_server):
def test_process_http_request_get(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is True
@@ -59,6 +59,6 @@ def test_process_http_request_get(web_server):
def test_process_http_request_method_not_allowed(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is False