add actions to enable/disable ports in routers/firewalls, improve notebook for training PPO agents
This commit is contained in:
@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for Port actions.
|
||||
|
||||
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
|
||||
can inherit from this base class.
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param num_nodes: Number of nodes in the simulation.
|
||||
:type num_nodes: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, port_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
|
||||
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
|
||||
if node_name is None or port_num is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_name, "network_interface", port_num, self.verb]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction):
|
||||
"""Action which enables a PORT."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction):
|
||||
"""Action which disables a NIC."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -602,6 +649,8 @@ class ActionManager:
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -45,7 +45,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"\n",
|
||||
"EPISODE_LEN = 128\n",
|
||||
"NO_STEPS = EPISODE_LEN * 10\n",
|
||||
"BATCH_SIZE = EPISODE_LEN * 10\n",
|
||||
"TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n",
|
||||
"LEARNING_RATE = 3e-4"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -54,7 +60,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -63,7 +69,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
"model.learn(total_timesteps=TOTAL_TIMESTEPS)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -72,7 +78,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
"model.save(\"PrimAITE-v3.0.0b7-PPO\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -80,7 +86,21 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"eval_model = PPO(\"MlpPolicy\", gym)\n",
|
||||
"eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
||||
"\n",
|
||||
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -99,7 +119,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -495,6 +495,8 @@ def game_and_agent():
|
||||
{"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}},
|
||||
{"type": "NETWORK_NIC_ENABLE"},
|
||||
{"type": "NETWORK_NIC_DISABLE"},
|
||||
{"type": "NETWORK_PORT_ENABLE"},
|
||||
{"type": "NETWORK_PORT_DISABLE"},
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
@@ -507,6 +509,7 @@ def game_and_agent():
|
||||
},
|
||||
{"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
|
||||
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
|
||||
{"node_name": "router"},
|
||||
],
|
||||
max_folders_per_node=2,
|
||||
max_files_per_folder=2,
|
||||
|
||||
@@ -312,3 +312,65 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
assert not client_1.file_system.get_file("downloads", "cat.png")
|
||||
# 3.1 (but with the reference to the original file, we can check that deleted flag is True )
|
||||
assert file.deleted
|
||||
|
||||
|
||||
def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that the NetworkPortDisableAction can form a request and that it is accepted by the simulation."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
# 1: Check that client_1 can access the network
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
server_1 = game.simulation.network.get_node_by_hostname("server_1")
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
action = (
|
||||
"NETWORK_PORT_DISABLE",
|
||||
{
|
||||
"node_id": 3, # router
|
||||
"port_id": 0, # port 1
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# 3: Check that the NIC is disabled, and that client 1 cannot access example.com
|
||||
assert router.network_interface[1].enabled == False
|
||||
assert not browser.get_webpage()
|
||||
assert not client_1.ping("10.0.2.2")
|
||||
assert not client_1.ping("10.0.2.3")
|
||||
|
||||
# 4: check that servers can still communicate
|
||||
assert server_1.ping("10.0.2.3")
|
||||
|
||||
|
||||
def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that the NetworkPortEnableAction can form a request and that it is accepted by the simulation."""
|
||||
|
||||
game, agent = game_and_agent
|
||||
|
||||
# 1: Disable router port 1
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
router.network_interface[1].disable()
|
||||
assert not client_1.ping("10.0.2.2")
|
||||
|
||||
# 2: Use action to enable port
|
||||
action = (
|
||||
"NETWORK_PORT_ENABLE",
|
||||
{
|
||||
"node_id": 3, # router
|
||||
"port_id": 0, # port 1
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# 3: Check that the Port is enabled, and that client 1 can ping again
|
||||
assert router.network_interface[1].enabled == True
|
||||
assert client_1.ping("10.0.2.3")
|
||||
|
||||
Reference in New Issue
Block a user