add actions to enable/disable ports in routers/firewalls, improve notebook for training PPO agents

This commit is contained in:
Cristian-VM2
2024-03-22 16:35:53 +00:00
parent cb9c14c87e
commit bef2bd8084
4 changed files with 140 additions and 6 deletions

View File

@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
self.verb: str = "disable"
class NetworkPortAbstractAction(AbstractAction):
"""
Abstract base class for Port actions.
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
can inherit from this base class.
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param num_nodes: Number of nodes in the simulation.
:type num_nodes: int
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, port_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
if node_name is None or port_num is None:
return ["do_nothing"]
return ["network", "node", node_name, "network_interface", port_num, self.verb]
class NetworkPortEnableAction(NetworkPortAbstractAction):
"""Action which enables a PORT."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction):
"""Action which disables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "disable"
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -602,6 +649,8 @@ class ActionManager:
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""