diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 092680a7..53b9b176 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable -from primaite.simulator.core import SimComponent +from primaite.simulator.core import Action, ActionManager, SimComponent from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol @@ -87,6 +87,36 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) + def _init_action_manager(self) -> ActionManager: + am = super()._init_action_manager() + + # When the request reaches this action, it should now contain solely positional args for the 'add_rule' action. + # POSITIONAL ARGUMENTS: + # 0: action (str name of an ACLAction) + # 1: protocol (str name of an IPProtocol) + # 2: source ip address (str castable to IPV4Address (e.g. '10.10.1.2')) + # 3: source port (str name of a Port (e.g. "HTTP")) # should we be using value, such as 80 or 443? + # 4: destination ip address (str castable to IPV4Address (e.g. '10.10.1.2')) + # 5: destination port (str name of a Port (e.g. "HTTP")) + # 6: position (int) + am.add_action( + "add_rule", + Action( + func=lambda request, context: self.add_rule( + ACLAction[request[0]], + IPProtocol[request[1]], + IPv4Address[request[2]], + Port[request[3]], + IPv4Address[request[4]], + Port[request[5]], + int(request[6]), + ) + ), + ) + + am.add_action("remove_rule", Action(func=lambda request, context: self.remove_rule(int(request[0])))) + return am + def describe_state(self) -> Dict: """ Describes the current state of the AccessControlList. @@ -596,6 +626,11 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp + def _init_action_manager(self) -> ActionManager: + am = super()._init_action_manager() + am.add_action("acl", Action(func=self.acl._action_manager)) + return am + def _get_port_of_nic(self, target_nic: NIC) -> Optional[int]: """ Retrieve the port number for a given NIC.