Merged PR 314: Add actions to enable/disable ports in routers/firewalls
## Summary Added actions to enable/disable ports for routers/firewalls Added tests for routers and added ticket in the backlog to test on firewalls (it's not possible to test at the moment because the game_and_agent fixture cannot build from a yaml file) Improved the notebook for training sb3 agents on UC2 to include validated value for learning rate and fix for loading sb3 PPO from IY as well as code to evaluate agents ## Test process Run new tests locally ## Checklist - [x] PR is linked to a **work item** - [x] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [x] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [x] ran **pre-commit** checks for code style - [x] attended to any **TO-DOs** left in the code Related work items: #2403
This commit is contained in:
@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class NetworkPortAbstractAction(AbstractAction):
|
||||
"""
|
||||
Abstract base class for Port actions.
|
||||
|
||||
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
|
||||
can inherit from this base class.
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
"""Init method for NetworkNICAbstractAction.
|
||||
|
||||
:param manager: Reference to the ActionManager which created this action.
|
||||
:type manager: ActionManager
|
||||
:param num_nodes: Number of nodes in the simulation.
|
||||
:type num_nodes: int
|
||||
:param max_nics_per_node: Maximum number of NICs per node.
|
||||
:type max_nics_per_node: int
|
||||
"""
|
||||
super().__init__(manager=manager)
|
||||
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
|
||||
self.verb: str # define but don't initialise: defends against children classes not defining this
|
||||
|
||||
def form_request(self, node_id: int, port_id: int) -> List[str]:
|
||||
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
|
||||
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
|
||||
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
|
||||
if node_name is None or port_num is None:
|
||||
return ["do_nothing"]
|
||||
return ["network", "node", node_name, "network_interface", port_num, self.verb]
|
||||
|
||||
|
||||
class NetworkPortEnableAction(NetworkPortAbstractAction):
|
||||
"""Action which enables a PORT."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "enable"
|
||||
|
||||
|
||||
class NetworkPortDisableAction(NetworkPortAbstractAction):
|
||||
"""Action which disables a PORT."""
|
||||
|
||||
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
|
||||
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
|
||||
self.verb: str = "disable"
|
||||
|
||||
|
||||
class ActionManager:
|
||||
"""Class which manages the action space for an agent."""
|
||||
|
||||
@@ -602,6 +649,8 @@ class ActionManager:
|
||||
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
|
||||
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
|
||||
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
|
||||
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
|
||||
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
|
||||
}
|
||||
"""Dictionary which maps action type strings to the corresponding action class."""
|
||||
|
||||
|
||||
@@ -45,7 +45,13 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3 import PPO"
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"\n",
|
||||
"EPISODE_LEN = 128\n",
|
||||
"NO_STEPS = EPISODE_LEN * 10\n",
|
||||
"BATCH_SIZE = EPISODE_LEN * 10\n",
|
||||
"TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n",
|
||||
"LEARNING_RATE = 3e-4"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -54,7 +60,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = PPO('MlpPolicy', gym)\n"
|
||||
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -63,7 +69,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.learn(total_timesteps=10)\n"
|
||||
"model.learn(total_timesteps=TOTAL_TIMESTEPS)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -72,7 +78,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.save(\"deleteme\")"
|
||||
"model.save(\"PrimAITE-v3.0.0b7-PPO\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -80,7 +86,21 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"eval_model = PPO(\"MlpPolicy\", gym)\n",
|
||||
"eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
||||
"\n",
|
||||
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -99,7 +119,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -519,12 +519,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
|
||||
"""
|
||||
super().enable()
|
||||
try:
|
||||
pass
|
||||
self._connected_node.default_gateway_hello()
|
||||
return True
|
||||
except AttributeError:
|
||||
pass
|
||||
return False
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
|
||||
@@ -495,6 +495,8 @@ def game_and_agent():
|
||||
{"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}},
|
||||
{"type": "NETWORK_NIC_ENABLE"},
|
||||
{"type": "NETWORK_NIC_DISABLE"},
|
||||
{"type": "NETWORK_PORT_ENABLE"},
|
||||
{"type": "NETWORK_PORT_DISABLE"},
|
||||
]
|
||||
|
||||
action_space = ActionManager(
|
||||
@@ -507,6 +509,7 @@ def game_and_agent():
|
||||
},
|
||||
{"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
|
||||
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
|
||||
{"node_name": "router"},
|
||||
],
|
||||
max_folders_per_node=2,
|
||||
max_files_per_folder=2,
|
||||
|
||||
@@ -312,3 +312,65 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
|
||||
assert not client_1.file_system.get_file("downloads", "cat.png")
|
||||
# 3.1 (but with the reference to the original file, we can check that deleted flag is True )
|
||||
assert file.deleted
|
||||
|
||||
|
||||
def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that the NetworkPortDisableAction can form a request and that it is accepted by the simulation."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
# 1: Check that client_1 can access the network
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
server_1 = game.simulation.network.get_node_by_hostname("server_1")
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
browser.target_url = "http://www.example.com"
|
||||
assert browser.get_webpage() # check that the browser can access example.com before we block it
|
||||
|
||||
# 2: Disable the NIC on client_1
|
||||
action = (
|
||||
"NETWORK_PORT_DISABLE",
|
||||
{
|
||||
"node_id": 3, # router
|
||||
"port_id": 0, # port 1
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# 3: Check that the NIC is disabled, and that client 1 cannot access example.com
|
||||
assert router.network_interface[1].enabled == False
|
||||
assert not browser.get_webpage()
|
||||
assert not client_1.ping("10.0.2.2")
|
||||
assert not client_1.ping("10.0.2.3")
|
||||
|
||||
# 4: check that servers can still communicate
|
||||
assert server_1.ping("10.0.2.3")
|
||||
|
||||
|
||||
def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
"""Test that the NetworkPortEnableAction can form a request and that it is accepted by the simulation."""
|
||||
|
||||
game, agent = game_and_agent
|
||||
|
||||
# 1: Disable router port 1
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
router.network_interface[1].disable()
|
||||
assert not client_1.ping("10.0.2.2")
|
||||
|
||||
# 2: Use action to enable port
|
||||
action = (
|
||||
"NETWORK_PORT_ENABLE",
|
||||
{
|
||||
"node_id": 3, # router
|
||||
"port_id": 0, # port 1
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
# 3: Check that the Port is enabled, and that client 1 can ping again
|
||||
assert router.network_interface[1].enabled == True
|
||||
assert client_1.ping("10.0.2.3")
|
||||
|
||||
Reference in New Issue
Block a user