add actions to enable/disable ports in routers/firewalls, improve notebook for training PPO agents

This commit is contained in:
Cristian-VM2
2024-03-22 16:35:53 +00:00
parent cb9c14c87e
commit bef2bd8084
4 changed files with 140 additions and 6 deletions

View File

@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
self.verb: str = "disable"
class NetworkPortAbstractAction(AbstractAction):
"""
Abstract base class for Port actions.
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
can inherit from this base class.
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param num_nodes: Number of nodes in the simulation.
:type num_nodes: int
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, port_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
if node_name is None or port_num is None:
return ["do_nothing"]
return ["network", "node", node_name, "network_interface", port_num, self.verb]
class NetworkPortEnableAction(NetworkPortAbstractAction):
"""Action which enables a PORT."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction):
"""Action which disables a NIC."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "disable"
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -602,6 +649,8 @@ class ActionManager:
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -45,7 +45,13 @@
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
"from stable_baselines3 import PPO\n",
"\n",
"EPISODE_LEN = 128\n",
"NO_STEPS = EPISODE_LEN * 10\n",
"BATCH_SIZE = EPISODE_LEN * 10\n",
"TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n",
"LEARNING_RATE = 3e-4"
]
},
{
@@ -54,7 +60,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
]
},
{
@@ -63,7 +69,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=10)\n"
"model.learn(total_timesteps=TOTAL_TIMESTEPS)\n"
]
},
{
@@ -72,7 +78,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
"model.save(\"PrimAITE-v3.0.0b7-PPO\")"
]
},
{
@@ -80,7 +86,21 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"eval_model = PPO(\"MlpPolicy\", gym)\n",
"eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
]
}
],
"metadata": {
@@ -99,7 +119,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.18"
}
},
"nbformat": 4,

View File

@@ -495,6 +495,8 @@ def game_and_agent():
{"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}},
{"type": "NETWORK_NIC_ENABLE"},
{"type": "NETWORK_NIC_DISABLE"},
{"type": "NETWORK_PORT_ENABLE"},
{"type": "NETWORK_PORT_DISABLE"},
]
action_space = ActionManager(
@@ -507,6 +509,7 @@ def game_and_agent():
},
{"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
{"node_name": "router"},
],
max_folders_per_node=2,
max_files_per_folder=2,

View File

@@ -312,3 +312,65 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
assert not client_1.file_system.get_file("downloads", "cat.png")
# 3.1 (but with the reference to the original file, we can check that deleted flag is True )
assert file.deleted
def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the NetworkPortDisableAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
# 1: Check that client_1 can access the network
client_1 = game.simulation.network.get_node_by_hostname("client_1")
server_1 = game.simulation.network.get_node_by_hostname("server_1")
router = game.simulation.network.get_node_by_hostname("router")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
action = (
"NETWORK_PORT_DISABLE",
{
"node_id": 3, # router
"port_id": 0, # port 1
},
)
agent.store_action(action)
game.step()
# 3: Check that the NIC is disabled, and that client 1 cannot access example.com
assert router.network_interface[1].enabled == False
assert not browser.get_webpage()
assert not client_1.ping("10.0.2.2")
assert not client_1.ping("10.0.2.3")
# 4: check that servers can still communicate
assert server_1.ping("10.0.2.3")
def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the NetworkPortEnableAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
# 1: Disable router port 1
router = game.simulation.network.get_node_by_hostname("router")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
router.network_interface[1].disable()
assert not client_1.ping("10.0.2.2")
# 2: Use action to enable port
action = (
"NETWORK_PORT_ENABLE",
{
"node_id": 3, # router
"port_id": 0, # port 1
},
)
agent.store_action(action)
game.step()
# 3: Check that the Port is enabled, and that client 1 can ping again
assert router.network_interface[1].enabled == True
assert client_1.ping("10.0.2.3")