diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 10c02b39..bfbefd3c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -223,8 +223,12 @@ class PrimaiteGame: sim = game.simulation net = sim.network - nodes_cfg = cfg["simulation"]["network"]["nodes"] - links_cfg = cfg["simulation"]["network"]["links"] + simulation_config = cfg.get("simulation", {}) + network_config = simulation_config.get("network", {}) + + nodes_cfg = network_config.get("nodes", []) + links_cfg = network_config.get("links", []) + for node_cfg in nodes_cfg: node_ref = node_cfg["ref"] n_type = node_cfg["type"] @@ -391,7 +395,7 @@ class PrimaiteGame: game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents - agents_cfg = cfg["agents"] + agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: agent_ref = agent_cfg["ref"] # noqa: F841 @@ -447,6 +451,6 @@ class PrimaiteGame: game.agents[agent_cfg["ref"]] = new_agent # Set the NMNE capture config - set_nmne_config(cfg["simulation"]["network"].get("nmne_config", {})) + set_nmne_config(network_config.get("nmne_config", {})) return game diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index b5a16430..6c2f38c5 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -8,10 +8,6 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import Router -from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service @@ -85,24 +81,29 @@ class Network(SimComponent): self.links[link_id].apply_timestep(timestep=timestep) @property - def routers(self) -> List[Router]: + def router_nodes(self) -> List[Node]: """The Routers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Router)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Router"] @property - def switches(self) -> List[Switch]: + def switch_nodes(self) -> List[Node]: """The Switches in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Switch)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Switch"] @property - def computers(self) -> List[Computer]: + def computer_nodes(self) -> List[Node]: """The Computers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Computer) and not isinstance(node, Server)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Computer"] @property - def servers(self) -> List[Server]: + def server_nodes(self) -> List[Node]: """The Servers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Server)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Server"] + + @property + def firewall_nodes(self) -> List[Node]: + """The Firewalls in the Network.""" + return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] def show(self, nodes: bool = True, ip_addresses: bool = True, links: bool = True, markdown: bool = False): """ @@ -117,10 +118,11 @@ class Network(SimComponent): :param markdown: Use Markdown style in table output. Defaults to False. """ nodes_type_map = { - "Router": self.routers, - "Switch": self.switches, - "Server": self.servers, - "Computer": self.computers, + "Router": self.router_nodes, + "Firewall": self.firewall_nodes, + "Switch": self.switch_nodes, + "Server": self.server_nodes, + "Computer": self.computer_nodes, } if nodes: table = PrettyTable(["Node", "Type", "Operating State"]) @@ -143,7 +145,10 @@ class Network(SimComponent): for node in nodes: for i, port in node.network_interface.items(): if hasattr(port, "ip_address"): - table.add_row([node.hostname, i, port.ip_address, port.subnet_mask, node.default_gateway]) + port_str = port.port_name if port.port_name else port.port_num + table.add_row( + [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + ) print(table) if links: diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index f5ddcfad..d7b1dfd9 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -500,7 +500,7 @@ class Firewall(Router): if "ports" in cfg: internal_port = cfg["ports"]["internal_port"] external_port = cfg["ports"]["external_port"] - dmz_port = cfg["ports"]["dmz_port"] + dmz_port = cfg["ports"].get("dmz_port") # configure internal port firewall.configure_internal_port( @@ -514,11 +514,12 @@ class Firewall(Router): subnet_mask=IPV4Address(external_port.get("subnet_mask", "255.255.255.0")), ) - # configure dmz port - firewall.configure_dmz_port( - ip_address=IPV4Address(dmz_port.get("ip_address")), - subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), - ) + # configure dmz port if not none + if dmz_port is not None: + firewall.configure_dmz_port( + ip_address=IPV4Address(dmz_port.get("ip_address")), + subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), + ) if "acl" in cfg: # acl rules for internal_inbound_acl if cfg["acl"]["internal_inbound_acl"]: @@ -573,7 +574,7 @@ class Firewall(Router): ) # acl rules for external_inbound_acl - if cfg["acl"]["external_inbound_acl"]: + if cfg["acl"].get("external_inbound_acl"): for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): firewall.external_inbound_acl.add_rule( action=ACLAction[r_cfg["action"]], @@ -586,7 +587,7 @@ class Firewall(Router): ) # acl rules for external_outbound_acl - if cfg["acl"]["external_outbound_acl"]: + if cfg["acl"].get("external_outbound_acl"): for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): firewall.external_outbound_acl.add_rule( action=ACLAction[r_cfg["action"]], diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml new file mode 100644 index 00000000..71dc31a7 --- /dev/null +++ b/tests/assets/configs/basic_firewall.yaml @@ -0,0 +1,174 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# + +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: GreenWebBrowsingAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + +simulation: + network: + nodes: + + - ref: firewall + type: firewall + hostname: firewall + start_up_duration: 0 + shut_down_duration: 0 + ports: + external_port: # port 1 + ip_address: 192.168.20.1 + subnet_mask: 255.255.255.0 + internal_port: # port 2 + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + acl: + internal_inbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + internal_outbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + dmz_inbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + dmz_outbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + type: switch + hostname: switch_1 + num_ports: 8 + - ref: switch_2 + type: switch + hostname: switch_2 + num_ports: 8 + + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + # pre installed services and applications + - ref: client_2 + type: computer + hostname: client_2 + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + # pre installed services and applications + + links: + - ref: switch_1___client_1 + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a_ref: switch_2 + endpoint_a_port: 1 + endpoint_b_ref: client_2 + endpoint_b_port: 1 + - ref: switch_1___firewall + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: firewall + endpoint_b_port: 1 + - ref: switch_2___firewall + endpoint_a_ref: switch_2 + endpoint_a_port: 2 + endpoint_b_ref: firewall + endpoint_b_port: 2 diff --git a/tests/assets/configs/no_nodes_links_agents_network.yaml b/tests/assets/configs/no_nodes_links_agents_network.yaml new file mode 100644 index 00000000..607a899a --- /dev/null +++ b/tests/assets/configs/no_nodes_links_agents_network.yaml @@ -0,0 +1,31 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP diff --git a/tests/conftest.py b/tests/conftest.py index f94a886b..b60de730 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -422,7 +422,7 @@ def install_stuff_to_sim(sim: Simulation): assert len(sim.network.nodes) == 6 assert len(sim.network.links) == 5 # 5.1: Assert the router is correctly configured - r = sim.network.routers[0] + r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: assert acl_rule.src_port == acl_rule.dst_port == Port.DNS diff --git a/tests/integration_tests/configuration_file_parsing/__init__.py b/tests/integration_tests/configuration_file_parsing/__init__.py index 1c8481d6..be21c036 100644 --- a/tests/integration_tests/configuration_file_parsing/__init__.py +++ b/tests/integration_tests/configuration_file_parsing/__init__.py @@ -10,6 +10,8 @@ BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" DMZ_NETWORK = TEST_ASSETS_ROOT / "configs/dmz_network.yaml" +BASIC_FIREWALL = TEST_ASSETS_ROOT / "configs/basic_firewall.yaml" + def load_config(config_path: Union[str, Path]) -> PrimaiteGame: """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py index 2e0556e9..fc6e05ec 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -1,3 +1,5 @@ +from ipaddress import IPv4Address + import pytest from primaite.simulator.network.container import Network @@ -8,7 +10,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import ACLAction from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port -from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config +from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config @pytest.fixture(scope="function") @@ -17,6 +19,12 @@ def dmz_config() -> Network: return game.simulation.network +@pytest.fixture(scope="function") +def basic_firewall_config() -> Network: + game = load_config(BASIC_FIREWALL) + return game.simulation.network + + def test_firewall_is_in_configuration(dmz_config): """Test that the firewall exists in the configuration file.""" network: Network = dmz_config @@ -109,3 +117,19 @@ def test_firewall_acl_rules_correctly_added(dmz_config): # external_outbound should have implicit action PERMIT # ICMP does not have a provided ACL Rule but implicit action should allow anything assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT + + +def test_firewall_with_no_dmz_port(basic_firewall_config): + """ + Test to check that: + - the DMZ port can be ignored i.e. is optional. + - the external_outbound_acl and external_inbound_acl are optional + """ + network: Network = basic_firewall_config + + firewall: Firewall = network.get_node_by_hostname("firewall") + + assert firewall.dmz_port.ip_address == IPv4Address("127.0.0.1") + + assert firewall.external_outbound_acl.num_rules == 0 + assert firewall.external_inbound_acl.num_rules == 0 diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index f23e7612..8797bf2e 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -11,9 +11,9 @@ def test_example_config(): network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network - assert len(network.routers) == 1 # 1 router in network - assert len(network.switches) == 2 # 2 switches in network - assert len(network.servers) == 5 # 5 servers in network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 2 # 2 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network def test_dmz_config(): @@ -23,9 +23,10 @@ def test_dmz_config(): network: Network = game.simulation.network assert len(network.nodes) == 9 # 9 nodes in network - assert len(network.routers) == 2 # 2 routers in network - assert len(network.switches) == 3 # 3 switches in network - assert len(network.servers) == 2 # 2 servers in network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.firewall_nodes) == 1 # 1 firewall in network + assert len(network.switch_nodes) == 3 # 3 switches in network + assert len(network.server_nodes) == 2 # 2 servers in network def test_basic_config(): diff --git a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index f3dc51bd..7da66547 100644 --- a/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -60,9 +60,9 @@ def test_example_config(): network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network - assert len(network.routers) == 1 # 1 router in network - assert len(network.switches) == 2 # 2 switches in network - assert len(network.servers) == 5 # 5 servers in network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 2 # 2 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network def test_node_software_install(): diff --git a/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py new file mode 100644 index 00000000..5c9b0cb9 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py @@ -0,0 +1,19 @@ +import yaml + +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +CONFIG_FILE = TEST_ASSETS_ROOT / "configs" / "no_nodes_links_agents_network.yaml" + + +def test_no_nodes_links_agents_config(): + """Tests PrimaiteGame can be created from config file where there are no nodes, links, agents in the config file.""" + with open(CONFIG_FILE, "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + network = game.simulation.network + + assert len(network.nodes) == 0 + assert len(network.links) == 0 diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 2cfc3f11..f0e386b8 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -26,10 +26,10 @@ def filter_keys_nested_item(data, keys): @pytest.fixture(scope="function") def network(example_network) -> Network: - assert len(example_network.routers) is 1 - assert len(example_network.switches) is 2 - assert len(example_network.computers) is 2 - assert len(example_network.servers) is 2 + assert len(example_network.router_nodes) is 1 + assert len(example_network.switch_nodes) is 2 + assert len(example_network.computer_nodes) is 2 + assert len(example_network.server_nodes) is 2 example_network.show()