Make actions more recursive

This commit is contained in:
Marek Wolan
2023-09-05 14:51:04 +01:00
parent 0c362da789
commit 2b68ed813c
5 changed files with 42 additions and 30 deletions

View File

@@ -48,6 +48,9 @@ class Action(BaseModel):
that invokes a class method of your SimComponent. For example if the component is a node and the action is for
turning it off, then the SimComponent should have a turn_off(self) method that does not need to accept any args.
Then, this Action will be given something like ``func = lambda request, context: self.turn_off()``.
``func`` can also be another action manager, since ActionManager is a callable with a signature that matches what is
expected by ``func``.
"""
validator: ActionPermissionValidator = AllowAllValidator()
"""
@@ -68,8 +71,9 @@ class ActionManager(BaseModel):
actions: Dict[str, Action] = {}
"""maps action verb to an action object."""
def process_request(self, request: List[str], context: Dict) -> None:
"""Process an action request.
def __call__(self, request: List[str], context: Dict) -> None:
"""
Process an action request.
:param request: A list of strings which specify what action to take. The first string must be one of the allowed
actions, i.e. it must be a key of self.actions. The subsequent strings in the list are passed as parameters
@@ -99,7 +103,8 @@ class ActionManager(BaseModel):
action.func(action_options, context)
def add_action(self, name: str, action: Action) -> None:
"""Add an action to this action manager.
"""
Add an action to this action manager.
:param name: The string associated to this action.
:type name: str
@@ -113,10 +118,32 @@ class ActionManager(BaseModel):
self.actions[name] = action
def list_actions(self) -> List[List[str]]:
def remove_action(self, name: str) -> None:
"""
Remove an action from this manager.
:param name: name identifier of the action
:type name: str
"""
if name not in self.actions:
msg = f"Attempted to remove action {name} from action manager, but it was not registered."
_LOGGER.error(msg)
raise RuntimeError(msg)
self.actions.pop(name)
def get_action_tree(self) -> List[List[str]]:
"""Recursively generate action tree for this component."""
actions = []
for act_name, act in self.actions.items():
pass # TODO:
if isinstance(act.func, ActionManager):
sub_actions = act.func.get_action_tree()
sub_actions = [[act_name]+a for a in sub_actions]
actions.extend(sub_actions)
else:
actions.append([act_name])
return actions
class SimComponent(BaseModel):
@@ -196,7 +223,7 @@ class SimComponent(BaseModel):
"""
if self.action_manager is None:
return
self.action_manager.process_request(action, context)
self.action_manager(action, context)
def apply_timestep(self, timestep: int) -> None:
"""

View File

@@ -46,13 +46,7 @@ class AccountGroup(Enum):
class GroupMembershipValidator(ActionPermissionValidator):
"""Permit actions based on group membership."""
def __init__(self, allowed_groups: List[AccountGroup]) -> None:
"""Store a list of groups that should be granted permission.
:param allowed_groups: List of AccountGroups that are permitted to perform some action.
:type allowed_groups: List[AccountGroup]
"""
self.allowed_groups = allowed_groups
allowed_groups:List[AccountGroup]
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
@@ -93,7 +87,7 @@ class DomainController(SimComponent):
"account",
Action(
func=lambda request, context: self.accounts[request.pop(0)].apply_action(request, context),
validator=GroupMembershipValidator([AccountGroup.DOMAIN_ADMIN]),
validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]),
),
)
return am

View File

@@ -43,12 +43,12 @@ class Network(SimComponent):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
self._node_action_manager = ActionManager()
am.add_action(
"node",
Action(
func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
validator=AllowAllValidator(),
func = self._node_action_manager
# func=lambda request, context: self.nodes[request.pop(0)].apply_action(request, context),
),
)
return am
@@ -182,6 +182,7 @@ class Network(SimComponent):
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
self._node_action_manager.add_action(name = node.uuid, action = Action(func=node._action_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:
"""
@@ -211,6 +212,7 @@ class Network(SimComponent):
self.nodes.pop(node.uuid)
node.parent = None
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
self._node_action_manager.remove_action(name = node.uuid)
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
"""

View File

@@ -135,7 +135,6 @@ class NIC(SimComponent):
{
"ip_adress": str(self.ip_address),
"subnet_mask": str(self.subnet_mask),
"gateway": str(self.gateway),
"mac_address": self.mac_address,
"speed": self.speed,
"mtu": self.mtu,

View File

@@ -24,19 +24,9 @@ class Simulation(SimComponent):
def _init_action_manager(self) -> ActionManager:
am = super()._init_action_manager()
# pass through network actions to the network objects
am.add_action(
"network",
Action(
func=lambda request, context: self.network.apply_action(request, context), validator=AllowAllValidator()
),
)
am.add_action("network", Action(func=self.network._action_manager))
# pass through domain actions to the domain object
am.add_action(
"domain",
Action(
func=lambda request, context: self.domain.apply_action(request, context), validator=AllowAllValidator()
),
)
am.add_action("domain", Action(func=self.domain._action_manager))
return am
def describe_state(self) -> Dict: