diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 812fffa8..c03bca36 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,9 +17,11 @@ from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.services.database.database_service import DatabaseService from primaite.simulator.system.services.dns.dns_client import DNSClient @@ -32,6 +34,24 @@ from primaite.simulator.system.services.web_server.web_server import WebServer _LOGGER = getLogger(__name__) +APPLICATION_TYPES_MAPPING = { + "WebBrowser": WebBrowser, + "DatabaseClient": DatabaseClient, + "DataManipulationBot": DataManipulationBot, + "DoSBot": DoSBot, +} + +SERVICE_TYPES_MAPPING = { + "DNSClient": DNSClient, + "DNSServer": DNSServer, + "DatabaseService": DatabaseService, + "WebServer": WebServer, + "FTPClient": FTPClient, + "FTPServer": FTPServer, + "NTPClient": NTPClient, + "NTPServer": NTPServer, +} + class PrimaiteGameOptions(BaseModel): """ @@ -239,54 +259,48 @@ class PrimaiteGame: new_service = None service_ref = service_cfg["ref"] service_type = service_cfg["type"] - service_types_mapping = { - "DNSClient": DNSClient, # key is equal to the 'name' attr of the service class itself. - "DNSServer": DNSServer, - "DatabaseClient": DatabaseClient, - "DatabaseService": DatabaseService, - "WebServer": WebServer, - "FTPClient": FTPClient, - "FTPServer": FTPServer, - "NTPClient": NTPClient, - "NTPServer": NTPServer, - } - if service_type in service_types_mapping: + if service_type in SERVICE_TYPES_MAPPING: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_types_mapping[service_type]) + new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type]) new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid else: _LOGGER.warning(f"service type not found {service_type}") # service-dependent options - if service_type == "DatabaseClient": + if service_type == "DNSClient": if "options" in service_cfg: opt = service_cfg["options"] - if "db_server_ip" in opt: - new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"])) + if "dns_server" in opt: + new_service.dns_server = IPv4Address(opt["dns_server"]) if service_type == "DNSServer": if "options" in service_cfg: opt = service_cfg["options"] if "domain_mapping" in opt: for domain, ip in opt["domain_mapping"].items(): - new_service.dns_register(domain, ip) + new_service.dns_register(domain, IPv4Address(ip)) if service_type == "DatabaseService": if "options" in service_cfg: opt = service_cfg["options"] - if "backup_server_ip" in opt: - new_service.configure_backup(backup_server=IPv4Address(opt["backup_server_ip"])) + new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) + new_service.start() + if service_type == "FTPServer": + if "options" in service_cfg: + opt = service_cfg["options"] + new_service.server_password = opt.get("server_password") + new_service.start() + if service_type == "NTPClient": + if "options" in service_cfg: + opt = service_cfg["options"] + new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip")) new_service.start() - if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: new_application = None application_ref = application_cfg["ref"] application_type = application_cfg["type"] - application_types_mapping = { - "WebBrowser": WebBrowser, - "DataManipulationBot": DataManipulationBot, - } - if application_type in application_types_mapping: - new_node.software_manager.install(application_types_mapping[application_type]) + + if application_type in APPLICATION_TYPES_MAPPING: + new_node.software_manager.install(APPLICATION_TYPES_MAPPING[application_type]) new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: @@ -302,10 +316,30 @@ class PrimaiteGame: port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), ) + elif application_type == "DatabaseClient": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.configure( + server_ip_address=IPv4Address(opt.get("db_server_ip")), + server_password=opt.get("server_password"), + ) elif application_type == "WebBrowser": if "options" in application_cfg: opt = application_cfg["options"] new_application.target_url = opt.get("target_url") + + elif application_type == "DoSBot": + if "options" in application_cfg: + opt = application_cfg["options"] + new_application.configure( + target_ip_address=IPv4Address(opt.get("target_ip_address")), + target_port=Port(opt.get("target_port", Port.POSTGRES_SERVER.value)), + payload=opt.get("payload"), + repeat=bool(opt.get("repeat")), + port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), + dos_intensity=float(opt.get("dos_intensity", "1.0")), + max_sessions=int(opt.get("max_sessions", "1000")), + ) if "network_interfaces" in node_cfg: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml new file mode 100644 index 00000000..d1cec079 --- /dev/null +++ b/tests/assets/configs/basic_switched_network.yaml @@ -0,0 +1,148 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + +simulation: + network: + nodes: + + - ref: switch_1 + type: switch + hostname: switch_1 + num_ports: 8 + + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - ref: dos_bot + type: DoSBot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - ref: client_1_dns_client + type: DNSClient + options: + dns_server: 192.168.1.10 + - ref: client_1_dns_server + type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.10 + - ref: client_1_database_service + type: DatabaseService + options: + backup_server_ip: 192.168.1.10 + - ref: client_1_web_service + type: WebServer + - ref: client_1_ftp_server + type: FTPServer + options: + server_password: arcd + - ref: client_1_ntp_client + type: NTPClient + options: + ntp_server_ip: 192.168.1.10 + - ref: client_1_ntp_server + type: NTPServer + - ref: client_2 + type: computer + hostname: client_2 + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + # pre installed services and applications + + links: + - ref: switch_1___client_1 + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_1___client_2 + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: client_2 + endpoint_b_port: 1 diff --git a/tests/integration_tests/game_configuration.py b/tests/integration_tests/game_configuration.py new file mode 100644 index 00000000..3bd870e3 --- /dev/null +++ b/tests/integration_tests/game_configuration.py @@ -0,0 +1,220 @@ +from ipaddress import IPv4Address +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import example_config_path +from primaite.game.agent.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot +from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.database.database_service import DatabaseService +from primaite.simulator.system.services.dns.dns_client import DNSClient +from primaite.simulator.system.services.dns.dns_server import DNSServer +from primaite.simulator.system.services.ftp.ftp_client import FTPClient +from primaite.simulator.system.services.ftp.ftp_server import FTPServer +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer +from primaite.simulator.system.services.web_server.web_server import WebServer +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_example_config(): + """Test that the example config can be parsed properly.""" + game = load_config(example_config_path()) + + assert len(game.agents) == 4 # red, blue and 2 green agents + + # green agent 1 + assert game.agents[0].agent_name == "client_2_green_user" + assert isinstance(game.agents[0], RandomAgent) + + # green agent 2 + assert game.agents[1].agent_name == "client_1_green_user" + assert isinstance(game.agents[1], RandomAgent) + + # red agent + assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot" + assert isinstance(game.agents[2], DataManipulationAgent) + + # blue agent + assert game.agents[3].agent_name == "defender" + assert isinstance(game.agents[3], ProxyAgent) + + network: Network = game.simulation.network + + assert len(network.nodes) == 10 # 10 nodes in example network + assert len(network.routers) == 1 # 1 router in network + assert len(network.switches) == 2 # 2 switches in network + assert len(network.servers) == 5 # 5 servers in network + + +def test_node_software_install(): + """Test that software can be installed on a node.""" + game = load_config(BASIC_CONFIG) + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_2: Computer = game.simulation.network.get_node_by_hostname("client_2") + + system_software = {DNSClient, FTPClient, NTPClient, WebBrowser} + + # check that system software is installed on client 1 + for software in system_software: + assert client_1.software_manager.software.get(software.__name__) is not None + + # check that system software is installed on client 2 + for software in system_software: + assert client_2.software_manager.software.get(software.__name__) is not None + + # check that applications have been installed on client 1 + for applications in APPLICATION_TYPES_MAPPING: + assert client_1.software_manager.software.get(applications) is not None + + # check that services have been installed on client 1 + for service in SERVICE_TYPES_MAPPING: + assert client_1.software_manager.software.get(service) is not None + + +def test_web_browser_install(): + """Test that the web browser can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + web_browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + + assert web_browser.target_url == "http://arcd.com/users/" + + +def test_database_client_install(): + """Test that the Database Client service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + database_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + + assert database_client.server_ip_address == IPv4Address("192.168.1.10") + assert database_client.server_password == "arcd" + + +def test_data_manipulation_bot_install(): + """Test that the data manipulation bot can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + + assert data_manipulation_bot.server_ip_address == IPv4Address("192.168.1.21") + assert data_manipulation_bot.payload == "DELETE" + assert data_manipulation_bot.data_manipulation_p_of_success == 0.8 + assert data_manipulation_bot.port_scan_p_of_success == 0.8 + assert data_manipulation_bot.server_password == "arcd" + + +def test_dos_bot_install(): + """Test that the denial of service bot can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + dos_bot: DoSBot = client_1.software_manager.software.get("DoSBot") + + assert dos_bot.target_ip_address == IPv4Address("192.168.10.21") + assert dos_bot.payload == "SPOOF DATA" + assert dos_bot.port_scan_p_of_success == 0.8 + assert dos_bot.dos_intensity == 1.0 # default + assert dos_bot.max_sessions == 1000 # default + assert dos_bot.repeat is False # default + + +def test_dns_client_install(): + """Test that the DNS Client service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + dns_client: DNSClient = client_1.software_manager.software.get("DNSClient") + + assert dns_client.dns_server == IPv4Address("192.168.1.10") + + +def test_dns_server_install(): + """Test that the DNS Client service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + dns_server: DNSServer = client_1.software_manager.software.get("DNSServer") + + assert dns_server.dns_lookup("arcd.com") == IPv4Address("192.168.1.10") + + +def test_database_service_install(): + """Test that the Database Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + database_service: DatabaseService = client_1.software_manager.software.get("DatabaseService") + + assert database_service.backup_server_ip == IPv4Address("192.168.1.10") + + +def test_web_server_install(): + """Test that the Web Server Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + web_server_service: WebServer = client_1.software_manager.software.get("WebServer") + + # config should have also installed database client - web server service should be able to retrieve this + assert web_server_service.software_manager.software.get("DatabaseClient") is not None + + +def test_ftp_client_install(): + """Test that the FTP Client Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + ftp_client_service: FTPClient = client_1.software_manager.software.get("FTPClient") + assert ftp_client_service is not None + + +def test_ftp_server_install(): + """Test that the FTP Server Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + ftp_server_service: FTPServer = client_1.software_manager.software.get("FTPServer") + assert ftp_server_service is not None + assert ftp_server_service.server_password == "arcd" + + +def test_ntp_client_install(): + """Test that the NTP Client Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + ntp_client_service: NTPClient = client_1.software_manager.software.get("NTPClient") + assert ntp_client_service is not None + assert ntp_client_service.ntp_server == IPv4Address("192.168.1.10") + + +def test_ntp_server_install(): + """Test that the NTP Server Service can be configured via config.""" + game = load_config(BASIC_CONFIG) + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + ntp_server_service: NTPServer = client_1.software_manager.software.get("NTPServer") + assert ntp_server_service is not None