add actions to enable/disable ports in routers/firewalls, improve notebook for training PPO agents
This commit is contained in:
@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for Port actions.
|
||||
|
||||
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
|
||||
can inherit from this base class.
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param num_nodes: Number of nodes in the simulation.
|
||||
:type num_nodes: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, port_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
|
||||
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
|
||||
if node_name is None or port_num is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_name, "network_interface", port_num, self.verb]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction):
|
||||
"""Action which enables a PORT."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -602,6 +649,8 @@ class ActionManager:
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -45,7 +45,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"\n",
|
||||
"EPISODE_LEN = 128\n",
|
||||
"NO_STEPS = EPISODE_LEN * 10\n",
|
||||
"BATCH_SIZE = EPISODE_LEN * 10\n",
|
||||
"TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n",
|
||||
"LEARNING_RATE = 3e-4"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -54,7 +60,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -63,7 +69,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
"model.learn(total_timesteps=TOTAL_TIMESTEPS)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -72,7 +78,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
"model.save(\"PrimAITE-v3.0.0b7-PPO\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -80,7 +86,21 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"eval_model = PPO(\"MlpPolicy\", gym)\n",
|
||||
"eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
||||
"\n",
|
||||
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -99,7 +119,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user