diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 4d28328e..af90c1e1 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -569,6 +569,53 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): self.verb: str = "disable" +class NetworkPortAbstractAction(AbstractAction): + """ + Abstract base class for Port actions. + + Any action which applies to a Router/Firewall and uses node_id and port_id as its only two parameters + can inherit from this base class. + """ + + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + """Init method for NetworkNICAbstractAction. + + :param manager: Reference to the ActionManager which created this action. + :type manager: ActionManager + :param num_nodes: Number of nodes in the simulation. + :type num_nodes: int + :param max_nics_per_node: Maximum number of NICs per node. + :type max_nics_per_node: int + """ + super().__init__(manager=manager) + self.shape: Dict[str, int] = {"node_id": num_nodes, "port_id": max_nics_per_node} + self.verb: str # define but don't initialise: defends against children classes not defining this + + def form_request(self, node_id: int, port_id: int) -> List[str]: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_idx=node_id) + port_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=port_id) + if node_name is None or port_num is None: + return ["do_nothing"] + return ["network", "node", node_name, "network_interface", port_num, self.verb] + + +class NetworkPortEnableAction(NetworkPortAbstractAction): + """Action which enables a PORT.""" + + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) + self.verb: str = "enable" + + +class NetworkPortDisableAction(NetworkPortAbstractAction): + """Action which disables a PORT.""" + + def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None: + super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs) + self.verb: str = "disable" + + class ActionManager: """Class which manages the action space for an agent.""" @@ -602,6 +649,8 @@ class ActionManager: "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction, "NETWORK_NIC_ENABLE": NetworkNICEnableAction, "NETWORK_NIC_DISABLE": NetworkNICDisableAction, + "NETWORK_PORT_ENABLE": NetworkPortEnableAction, + "NETWORK_PORT_DISABLE": NetworkPortDisableAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index cefcc429..e6f5aaee 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -45,7 +45,13 @@ "metadata": {}, "outputs": [], "source": [ - "from stable_baselines3 import PPO" + "from stable_baselines3 import PPO\n", + "\n", + "EPISODE_LEN = 128\n", + "NO_STEPS = EPISODE_LEN * 10\n", + "BATCH_SIZE = EPISODE_LEN * 10\n", + "TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n", + "LEARNING_RATE = 3e-4" ] }, { @@ -54,7 +60,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = PPO('MlpPolicy', gym)\n" + "model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n" ] }, { @@ -63,7 +69,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=10)\n" + "model.learn(total_timesteps=TOTAL_TIMESTEPS)\n" ] }, { @@ -72,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.save(\"deleteme\")" + "model.save(\"PrimAITE-v3.0.0b7-PPO\")" ] }, { @@ -80,7 +86,21 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "eval_model = PPO(\"MlpPolicy\", gym)\n", + "eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", + "evaluate_policy(eval_model, gym, n_eval_episodes=10)" + ] } ], "metadata": { @@ -99,7 +119,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 0cad4124..38d20e1f 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -519,12 +519,10 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): """ super().enable() try: - pass self._connected_node.default_gateway_hello() - return True except AttributeError: pass - return False + return True @abstractmethod def receive_frame(self, frame: Frame) -> bool: diff --git a/tests/conftest.py b/tests/conftest.py index 3a9e2655..fbfd23f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -495,6 +495,8 @@ def game_and_agent(): {"type": "NETWORK_ACL_REMOVERULE", "options": {"target_router_hostname": "router"}}, {"type": "NETWORK_NIC_ENABLE"}, {"type": "NETWORK_NIC_DISABLE"}, + {"type": "NETWORK_PORT_ENABLE"}, + {"type": "NETWORK_PORT_DISABLE"}, ] action_space = ActionManager( @@ -507,6 +509,7 @@ def game_and_agent(): }, {"node_name": "server_1", "services": [{"service_name": "DNSServer"}]}, {"node_name": "server_2", "services": [{"service_name": "WebServer"}]}, + {"node_name": "router"}, ], max_folders_per_node=2, max_files_per_folder=2, diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 740fb491..5ced802c 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -312,3 +312,65 @@ def test_node_file_delete_integration(game_and_agent: Tuple[PrimaiteGame, ProxyA assert not client_1.file_system.get_file("downloads", "cat.png") # 3.1 (but with the reference to the original file, we can check that deleted flag is True ) assert file.deleted + + +def test_network_router_port_disable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the NetworkPortDisableAction can form a request and that it is accepted by the simulation.""" + game, agent = game_and_agent + + # 1: Check that client_1 can access the network + client_1 = game.simulation.network.get_node_by_hostname("client_1") + server_1 = game.simulation.network.get_node_by_hostname("server_1") + router = game.simulation.network.get_node_by_hostname("router") + + browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") + browser.run() + browser.target_url = "http://www.example.com" + assert browser.get_webpage() # check that the browser can access example.com before we block it + + # 2: Disable the NIC on client_1 + action = ( + "NETWORK_PORT_DISABLE", + { + "node_id": 3, # router + "port_id": 0, # port 1 + }, + ) + agent.store_action(action) + game.step() + + # 3: Check that the NIC is disabled, and that client 1 cannot access example.com + assert router.network_interface[1].enabled == False + assert not browser.get_webpage() + assert not client_1.ping("10.0.2.2") + assert not client_1.ping("10.0.2.3") + + # 4: check that servers can still communicate + assert server_1.ping("10.0.2.3") + + +def test_network_router_port_enable_integration(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): + """Test that the NetworkPortEnableAction can form a request and that it is accepted by the simulation.""" + + game, agent = game_and_agent + + # 1: Disable router port 1 + router = game.simulation.network.get_node_by_hostname("router") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + router.network_interface[1].disable() + assert not client_1.ping("10.0.2.2") + + # 2: Use action to enable port + action = ( + "NETWORK_PORT_ENABLE", + { + "node_id": 3, # router + "port_id": 0, # port 1 + }, + ) + agent.store_action(action) + game.step() + + # 3: Check that the Port is enabled, and that client 1 can ping again + assert router.network_interface[1].enabled == True + assert client_1.ping("10.0.2.3")