From 2b68ed813c595576ff8dc091fdd8f60796467490 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 5 Sep 2023 14:51:04 +0100 Subject: [PATCH] Make actions more recursive --- src/primaite/simulator/core.py | 39 ++++++++++++++++--- src/primaite/simulator/domain/controller.py | 10 +---- src/primaite/simulator/network/container.py | 8 ++-- .../simulator/network/hardware/base.py | 1 - src/primaite/simulator/sim_container.py | 14 +------ 5 files changed, 42 insertions(+), 30 deletions(-) diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index c38a7e2f..e8cd4b98 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -48,6 +48,9 @@ class Action(BaseModel): that invokes a class method of your SimComponent. For example if the component is a node and the action is for turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args. Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``. + + ``func`` can also be another action manager, since ActionManager is a callable with a signature that matches what is + expected by ``func``. """ validator: ActionPermissionValidator = AllowAllValidator() """ @@ -68,8 +71,9 @@ class ActionManager(BaseModel): actions: Dict[str, Action] = {} """maps action verb to an action object.""" - def process_request(self, request: List[str], context: Dict) -> None: - """Process an action request. + def __call__(self, request: List[str], context: Dict) -> None: + """ + Process an action request. :param request: A list of strings which specify what action to take. The first string must be one of the allowed actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters @@ -99,7 +103,8 @@ class ActionManager(BaseModel): action.func(action_options, context) def add_action(self, name: str, action: Action) -> None: - """Add an action to this action manager. + """ + Add an action to this action manager. :param name: The string associated to this action. :type name: str @@ -113,10 +118,32 @@ class ActionManager(BaseModel): self.actions[name] = action - def list_actions(self) -> List[List[str]]: + def remove_action(self, name: str) -> None: + """ + Remove an action from this manager. + + :param name: name identifier of the action + :type name: str + """ + if name not in self.actions: + msg = f"Attempted to remove action {name} from action manager, but it was not registered." + _LOGGER.error(msg) + raise RuntimeError(msg) + + self.actions.pop(name) + + + def get_action_tree(self) -> List[List[str]]: + """Recursively generate action tree for this component.""" actions = [] for act_name, act in self.actions.items(): - pass # TODO: + if isinstance(act.func, ActionManager): + sub_actions = act.func.get_action_tree() + sub_actions = [[act_name]+a for a in sub_actions] + actions.extend(sub_actions) + else: + actions.append([act_name]) + return actions class SimComponent(BaseModel): @@ -196,7 +223,7 @@ class SimComponent(BaseModel): """ if self.action_manager is None: return - self.action_manager.process_request(action, context) + self.action_manager(action, context) def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index b436ca79..cd0fe9de 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -46,13 +46,7 @@ class AccountGroup(Enum): class GroupMembershipValidator(ActionPermissionValidator): """Permit actions based on group membership.""" - def __init__(self, allowed_groups: List[AccountGroup]) -> None: - """Store a list of groups that should be granted permission. - - :param allowed_groups: List of AccountGroups that are permitted to perform some action. - :type allowed_groups: List[AccountGroup] - """ - self.allowed_groups = allowed_groups + allowed_groups:List[AccountGroup] def __call__(self, request: List[str], context: Dict) -> bool: """Permit the action if the request comes from an account which belongs to the right group.""" @@ -93,7 +87,7 @@ class DomainController(SimComponent): "account", Action( func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context), - validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]), + validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]), ), ) return am diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 4d1afe72..1c7bbec7 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -43,12 +43,12 @@ class Network(SimComponent): def _init_action_manager(self) -> ActionManager: am = super()._init_action_manager() - + self._node_action_manager = ActionManager() am.add_action( "node", Action( - func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context), - validator=AllowAllValidator(), + func = self._node_action_manager + # func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context), ), ) return am @@ -182,6 +182,7 @@ class Network(SimComponent): node.parent = self self._nx_graph.add_node(node.hostname) _LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}") + self._node_action_manager.add_action(name = node.uuid, action = Action(func=node._action_manager)) def get_node_by_hostname(self, hostname: str) -> Optional[Node]: """ @@ -211,6 +212,7 @@ class Network(SimComponent): self.nodes.pop(node.uuid) node.parent = None _LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}") + self._node_action_manager.remove_action(name = node.uuid) def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None: """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 101d6b72..a846f7e2 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -135,7 +135,6 @@ class NIC(SimComponent): { "ip_adress": str(self.ip_address), "subnet_mask": str(self.subnet_mask), - "gateway": str(self.gateway), "mac_address": self.mac_address, "speed": self.speed, "mtu": self.mtu, diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 2a5123f3..d647b0bc 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -24,19 +24,9 @@ class Simulation(SimComponent): def _init_action_manager(self) -> ActionManager: am = super()._init_action_manager() # pass through network actions to the network objects - am.add_action( - "network", - Action( - func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator() - ), - ) + am.add_action("network", Action(func=self.network._action_manager)) # pass through domain actions to the domain object - am.add_action( - "domain", - Action( - func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator() - ), - ) + am.add_action("domain", Action(func=self.domain._action_manager)) return am def describe_state(self) -> Dict: