From 73a75c497b612bbacb352373070c7391205d4c73 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 25 Jan 2024 13:13:50 +0000 Subject: [PATCH] Fix test --- .../network/hardware/nodes/router.py | 21 ++++++++++-- .../assets/configs/bad_primaite_session.yaml | 17 ++++++++++ .../configs/eval_only_primaite_session.yaml | 17 ++++++++++ tests/assets/configs/multi_agent_session.yaml | 34 +++++++++++++++++++ .../assets/configs/test_primaite_session.yaml | 17 ++++++++++ .../configs/train_only_primaite_session.yaml | 17 ++++++++++ .../_simulator/_network/test_container.py | 24 ++++++++++--- 7 files changed, 140 insertions(+), 7 deletions(-) diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 0c5d0ce9..41c14967 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -90,7 +90,7 @@ class AccessControlList(SimComponent): implicit_rule: ACLRule max_acl_rules: int = 25 _acl: List[Optional[ACLRule]] = [None] * 24 - _default_config: dict[int, dict] = {} + _default_config: Dict[int, dict] = {} """Config dict describing how the ACL list should look at episode start""" def __init__(self, **kwargs) -> None: @@ -109,6 +109,21 @@ class AccessControlList(SimComponent): vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) + for i, rule in enumerate(self._acl): + if not rule: + continue + self._default_config[i] = {"action": rule.action.name} + if rule.src_ip_address: + self._default_config[i]["src_ip"] = str(rule.src_ip_address) + if rule.dst_ip_address: + self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) + if rule.src_port: + self._default_config[i]["src_port"] = rule.src_port.name + if rule.dst_port: + self._default_config[i]["dst_port"] = rule.dst_port.name + if rule.protocol: + self._default_config[i]["protocol"] = rule.protocol.name + def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.implicit_rule.reset_component_for_episode(episode) @@ -124,8 +139,8 @@ class AccessControlList(SimComponent): src_port=None if not (p := r_cfg.get("src_port")) else Port[p], dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("ip_address"), - dst_ip_address=r_cfg.get("ip_address"), + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), position=r_num, ) diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index e5458670..4a1fc275 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -491,6 +491,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 767279ce..c8ffa23f 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -502,6 +502,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index 6290fa53..6cd22694 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -509,6 +509,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: @@ -940,6 +957,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 89b88475..99087798 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -507,6 +507,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index b9fa1216..c2842a06 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -503,6 +503,23 @@ agents: max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_ref: domain_controller + nic_num: 1 + - node_ref: web_server + nic_num: 1 + - node_ref: database_server + nic_num: 1 + - node_ref: backup_server + nic_num: 1 + - node_ref: security_suite + nic_num: 1 + - node_ref: client_1 + nic_num: 1 + - node_ref: client_2 + nic_num: 1 + - node_ref: security_suite + nic_num: 2 reward_function: reward_components: diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index e348838e..7667a59f 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien from primaite.simulator.system.services.database.database_service import DatabaseService +def filter_keys_nested_item(data, keys): + stack = [(data, {})] + while stack: + current, filtered = stack.pop() + if isinstance(current, dict): + for k, v in current.items(): + if k in keys: + filtered[k] = filter_keys_nested_item(v, keys) + elif isinstance(v, (dict, list)): + stack.append((v, {})) + elif isinstance(current, list): + for item in current: + stack.append((item, {})) + return filtered + + @pytest.fixture(scope="function") def network(example_network) -> Network: assert len(example_network.routers) is 1 @@ -59,10 +75,10 @@ def test_reset_network(network): assert client_1.operating_state is NodeOperatingState.ON assert server_1.operating_state is NodeOperatingState.ON - - assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps( - state_before, sort_keys=True, indent=2 - ) + # don't worry if UUIDs change + a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) + b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) + assert a == b def test_creating_container():