From fabd4fd5ddd957ff8abfa3a9d361d04d743fb490 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 4 Oct 2023 09:07:04 +0100 Subject: [PATCH] Add ACL Action to game layer --- src/primaite/game/agent/actions.py | 48 ++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index cb7061fc..6c4ae3b2 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -227,11 +227,39 @@ class NodeResetAction(NodeAbstractAction): self.verb = 'reset' class NetworkACLAddRuleAction(AbstractAction): - def __init__(self, manager: "ActionManager", **kwargs) -> None: + def __init__(self, + manager: "ActionManager", + target_router_uuid:str, + max_acl_rules:int, + num_ips:int, + num_ports:int, + num_protocols:int, + **kwargs) -> None: super().__init__(manager=manager) num_permissions = 2 - self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_nics, num_nics, num_ports, num_ports, num_protocols) + self.shape: Tuple[int] = (max_acl_rules, num_permissions, num_ips, num_ips, num_ports, num_ports, num_protocols) + self.target_router_uuid:str = target_router_uuid + def form_request(self, position, permission, source_ip_idx, dest_ip_idx, source_port_idx, dest_port_idx, protocol_idx) -> List[str]: + protocol = self.manager.get_internet_protocol_by_idx(protocol_idx) + src_ip = self.manager.get_ip_address_by_idx(source_ip_idx) + src_port = self.manager.get_port_by_idx(source_port_idx) + dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx) + dst_port = self.manager.get_port_by_idx(dest_port_idx) + return [ + 'network', + 'node', + self.target_router_uuid, + 'acl', + 'add_rule', + permission, + protocol, + src_ip, + src_port, + dst_ip, + dst_port, + position + ] @@ -289,9 +317,14 @@ class ActionManager: max_services_per_node:int = 2, max_nics_per_node:int=8, max_acl_rules:int=10, + protocols:List[str]=['TCP','UDP','ICMP'], + ports:List[str]=['HTTP','DNS','ARP','FTP'], + ip_address_list:Optional[List[str]]=None, act_map:Optional[Dict[int, Dict]]=None) -> None: self.sim: Simulation = sim self.node_uuids:List[str] = node_uuids + self.protocols:List[str] = protocols + self.ports:List[str] = ports action_args = { "num_nodes": len(node_uuids), @@ -299,7 +332,10 @@ class ActionManager: "num_files": max_files_per_folder, "num_services": max_services_per_node, "num_nics": max_nics_per_node, - "num_acl_rules": max_acl_rules} + "num_acl_rules": max_acl_rules, + "num_protocols": len(self.protocols), + "num_ports": len(self.protocols), + "num_ips":} self.actions: Dict[str, AbstractAction] = {} for act_type in actions: self.actions[act_type] = self.__act_class_identifiers[act_type](self, **action_args) @@ -362,8 +398,14 @@ class ActionManager: service_uuids = list(node.services.keys()) return service_uuids[service_idx] if len(service_uuids)>service_idx else None + def get_internet_protocol_by_idx(self, protocol_idx:int) -> str: + # protocol = self.manager.get_internet_protocol_by_idx(protocol_idx) + # src_ip = self.manager.get_ip_address_by_idx(source_ip_idx) + # src_port = self.manager.get_port_by_idx(source_port_idx) + # dst_ip = self.manager.get_ip_address_by_idx(dest_ip_idx) + # dst_port = self.manager.get_port_by_idx(dest_port_idx)