This commit is contained in:
Marek Wolan
2024-01-25 13:13:50 +00:00
parent e7aac754a0
commit 73a75c497b
7 changed files with 140 additions and 7 deletions

View File

@@ -90,7 +90,7 @@ class AccessControlList(SimComponent):
implicit_rule: ACLRule
max_acl_rules: int = 25
_acl: List[Optional[ACLRule]] = [None] * 24
_default_config: dict[int, dict] = {}
_default_config: Dict[int, dict] = {}
"""Config dict describing how the ACL list should look at episode start"""
def __init__(self, **kwargs) -> None:
@@ -109,6 +109,21 @@ class AccessControlList(SimComponent):
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
for i, rule in enumerate(self._acl):
if not rule:
continue
self._default_config[i] = {"action": rule.action.name}
if rule.src_ip_address:
self._default_config[i]["src_ip"] = str(rule.src_ip_address)
if rule.dst_ip_address:
self._default_config[i]["dst_ip"] = str(rule.dst_ip_address)
if rule.src_port:
self._default_config[i]["src_port"] = rule.src_port.name
if rule.dst_port:
self._default_config[i]["dst_port"] = rule.dst_port.name
if rule.protocol:
self._default_config[i]["protocol"] = rule.protocol.name
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
@@ -124,8 +139,8 @@ class AccessControlList(SimComponent):
src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
src_ip_address=r_cfg.get("ip_address"),
dst_ip_address=r_cfg.get("ip_address"),
src_ip_address=r_cfg.get("src_ip"),
dst_ip_address=r_cfg.get("dst_ip"),
position=r_num,
)

View File

@@ -491,6 +491,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -502,6 +502,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -509,6 +509,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:
@@ -940,6 +957,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -507,6 +507,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -503,6 +503,23 @@ agents:
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
reward_function:
reward_components:

View File

@@ -10,6 +10,22 @@ from primaite.simulator.system.applications.database_client import DatabaseClien
from primaite.simulator.system.services.database.database_service import DatabaseService
def filter_keys_nested_item(data, keys):
stack = [(data, {})]
while stack:
current, filtered = stack.pop()
if isinstance(current, dict):
for k, v in current.items():
if k in keys:
filtered[k] = filter_keys_nested_item(v, keys)
elif isinstance(v, (dict, list)):
stack.append((v, {}))
elif isinstance(current, list):
for item in current:
stack.append((item, {}))
return filtered
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
@@ -59,10 +75,10 @@ def test_reset_network(network):
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
# don't worry if UUIDs change
a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"])
b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"])
assert a == b
def test_creating_container():