Merged PR 314: Add actions to enable/disable ports in routers/firewalls

## Summary
Added actions to enable/disable ports for routers/firewalls
Added tests for routers and added ticket in the backlog to test on firewalls (it's not possible to test at the moment because the game_and_agent fixture cannot build from a yaml file)

Improved the notebook for training sb3 agents on UC2 to include validated value for learning rate and fix for loading sb3 PPO from IY as well as code to evaluate agents

## Test process
Run new tests locally

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [x] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2403
This commit is contained in:
Cristian Genes
2024-03-25 12:34:21 +00:00
5 changed files with 141 additions and 9 deletions

View File

@@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
self.verb: str = "disable"
class NetworkPortAbstractAction(AbstractAction):
"""
Abstract base class for Port actions.
Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters
can inherit from this base class.
"""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
"""Init method for NetworkNICAbstractAction.
:param manager: Reference to the ActionManager which created this action.
:type manager: ActionManager
:param num_nodes: Number of nodes in the simulation.
:type num_nodes: int
:param max_nics_per_node: Maximum number of NICs per node.
:type max_nics_per_node: int
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, port_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id)
if node_name is None or port_num is None:
return ["do_nothing"]
return ["network", "node", node_name, "network_interface", port_num, self.verb]
class NetworkPortEnableAction(NetworkPortAbstractAction):
"""Action which enables a PORT."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "enable"
class NetworkPortDisableAction(NetworkPortAbstractAction):
"""Action which disables a PORT."""
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb: str = "disable"
class ActionManager:
"""Class which manages the action space for an agent."""
@@ -602,6 +649,8 @@ class ActionManager:
"NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
"NETWORK_NIC_ENABLE": NetworkNICEnableAction,
"NETWORK_NIC_DISABLE": NetworkNICDisableAction,
"NETWORK_PORT_ENABLE": NetworkPortEnableAction,
"NETWORK_PORT_DISABLE": NetworkPortDisableAction,
}
"""Dictionary which maps action type strings to the corresponding action class."""

View File

@@ -45,7 +45,13 @@
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3 import PPO"
"from stable_baselines3 import PPO\n",
"\n",
"EPISODE_LEN = 128\n",
"NO_STEPS = EPISODE_LEN * 10\n",
"BATCH_SIZE = EPISODE_LEN * 10\n",
"TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n",
"LEARNING_RATE = 3e-4"
]
},
{
@@ -54,7 +60,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = PPO('MlpPolicy', gym)\n"
"model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n"
]
},
{
@@ -63,7 +69,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.learn(total_timesteps=10)\n"
"model.learn(total_timesteps=TOTAL_TIMESTEPS)\n"
]
},
{
@@ -72,7 +78,7 @@
"metadata": {},
"outputs": [],
"source": [
"model.save(\"deleteme\")"
"model.save(\"PrimAITE-v3.0.0b7-PPO\")"
]
},
{
@@ -80,7 +86,21 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"eval_model = PPO(\"MlpPolicy\", gym)\n",
"eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"evaluate_policy(eval_model, gym, n_eval_episodes=10)"
]
}
],
"metadata": {
@@ -99,7 +119,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.18"
}
},
"nbformat": 4,

View File

@@ -519,12 +519,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC):
"""
super().enable()
try:
pass
self._connected_node.default_gateway_hello()
return True
except AttributeError:
pass
return False
return True
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:

View File

@@ -495,6 +495,8 @@ def game_and_agent():
{"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}},
{"type": "NETWORK_NIC_ENABLE"},
{"type": "NETWORK_NIC_DISABLE"},
{"type": "NETWORK_PORT_ENABLE"},
{"type": "NETWORK_PORT_DISABLE"},
]
action_space = ActionManager(
@@ -507,6 +509,7 @@ def game_and_agent():
},
{"node_name": "server_1", "services": [{"service_name": "DNSServer"}]},
{"node_name": "server_2", "services": [{"service_name": "WebServer"}]},
{"node_name": "router"},
],
max_folders_per_node=2,
max_files_per_folder=2,

View File

@@ -312,3 +312,65 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA
assert not client_1.file_system.get_file("downloads", "cat.png")
# 3.1 (but with the reference to the original file, we can check that deleted flag is True )
assert file.deleted
def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the NetworkPortDisableAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
# 1: Check that client_1 can access the network
client_1 = game.simulation.network.get_node_by_hostname("client_1")
server_1 = game.simulation.network.get_node_by_hostname("server_1")
router = game.simulation.network.get_node_by_hostname("router")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
browser.target_url = "http://www.example.com"
assert browser.get_webpage() # check that the browser can access example.com before we block it
# 2: Disable the NIC on client_1
action = (
"NETWORK_PORT_DISABLE",
{
"node_id": 3, # router
"port_id": 0, # port 1
},
)
agent.store_action(action)
game.step()
# 3: Check that the NIC is disabled, and that client 1 cannot access example.com
assert router.network_interface[1].enabled == False
assert not browser.get_webpage()
assert not client_1.ping("10.0.2.2")
assert not client_1.ping("10.0.2.3")
# 4: check that servers can still communicate
assert server_1.ping("10.0.2.3")
def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the NetworkPortEnableAction can form a request and that it is accepted by the simulation."""
game, agent = game_and_agent
# 1: Disable router port 1
router = game.simulation.network.get_node_by_hostname("router")
client_1 = game.simulation.network.get_node_by_hostname("client_1")
router.network_interface[1].disable()
assert not client_1.ping("10.0.2.2")
# 2: Use action to enable port
action = (
"NETWORK_PORT_ENABLE",
{
"node_id": 3, # router
"port_id": 0, # port 1
},
)
agent.store_action(action)
game.step()
# 3: Check that the Port is enabled, and that client 1 can ping again
assert router.network_interface[1].enabled == True
assert client_1.ping("10.0.2.3")