diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index e6f5aaee..140df1b8 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -1,5 +1,21 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training an SB3 Agent\n", + "\n", + "This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### First, we import the inital packages and read in our configuration file." + ] + }, { "cell_type": "code", "execution_count": null, @@ -27,7 +43,14 @@ "outputs": [], "source": [ "with open(data_manipulation_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n" + " cfg = yaml.safe_load(f)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the given configuration, we generate the environment our agent will train in." ] }, { @@ -39,6 +62,13 @@ "gym = PrimaiteGymEnv(game_config=cfg)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lets define training parameters for the agent." + ] + }, { "cell_type": "code", "execution_count": null, @@ -48,9 +78,9 @@ "from stable_baselines3 import PPO\n", "\n", "EPISODE_LEN = 128\n", - "NO_STEPS = EPISODE_LEN * 10\n", - "BATCH_SIZE = EPISODE_LEN * 10\n", - "TOTAL_TIMESTEPS = 5e3 * EPISODE_LEN\n", + "NUM_EPISODES = 10\n", + "NO_STEPS = EPISODE_LEN * NUM_EPISODES\n", + "BATCH_SIZE = 32\n", "LEARNING_RATE = 3e-4" ] }, @@ -60,7 +90,14 @@ "metadata": {}, "outputs": [], "source": [ - "model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")\n" + "model = PPO('MlpPolicy', gym, learning_rate=LEARNING_RATE, n_steps=NO_STEPS, batch_size=BATCH_SIZE, verbose=0, tensorboard_log=\"./PPO_UC2/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the agent configured, let's train for our defined number of episodes." ] }, { @@ -69,7 +106,14 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=TOTAL_TIMESTEPS)\n" + "model.learn(total_timesteps=NO_STEPS)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let's save the agent to a zip file that can be used in future evaluation." ] }, { @@ -78,7 +122,14 @@ "metadata": {}, "outputs": [], "source": [ - "model.save(\"PrimAITE-v3.0.0b7-PPO\")" + "model.save(\"PrimAITE-PPO-example-agent\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we load the saved agent and run it in evaluation mode." ] }, { @@ -88,7 +139,14 @@ "outputs": [], "source": [ "eval_model = PPO(\"MlpPolicy\", gym)\n", - "eval_model = PPO.load(\"PrimAITE-v3.0.0b7-PPO\", gym)" + "eval_model = PPO.load(\"PrimAITE-PPO-example-agent\", gym)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, evaluate the agent." ] }, { @@ -119,7 +177,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index f403176a..06ecd4be 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -6,7 +6,7 @@ "source": [ "# Build a simulation using the Python API\n", "\n", - "Currently, this notbook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n" + "Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n" ] }, { @@ -58,7 +58,8 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.hardware.base import Node\n" + "from primaite.simulator.network.hardware.nodes.host.computer import Computer\n", + "from primaite.simulator.network.hardware.nodes.host.server import Server" ] }, { @@ -67,9 +68,9 @@ "metadata": {}, "outputs": [], "source": [ - "my_pc = Node(hostname=\"primaite_pc\",)\n", + "my_pc = Computer(hostname=\"Computer\", ip_address=\"192.168.1.10\", subnet_mask=\"255.255.255.0\")\n", "net.add_node(my_pc)\n", - "my_server = Node(hostname=\"google_server\")\n", + "my_server = Server(hostname=\"Server\", ip_address=\"192.168.1.11\", subnet_mask=\"255.255.255.0\")\n", "net.add_node(my_server)\n" ] }, @@ -86,7 +87,8 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.hardware.base import NIC, Link, Switch\n" + "from primaite.simulator.network.hardware.nodes.host.host_node import NIC\n", + "from primaite.simulator.network.hardware.nodes.network.switch import Switch\n" ] }, { @@ -95,19 +97,17 @@ "metadata": {}, "outputs": [], "source": [ - "my_swtich = Switch(hostname=\"switch1\", num_ports=12)\n", - "net.add_node(my_swtich)\n", + "my_switch = Switch(hostname=\"switch1\", num_ports=12)\n", + "net.add_node(my_switch)\n", "\n", "pc_nic = NIC(ip_address=\"130.1.1.1\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n", "my_pc.connect_nic(pc_nic)\n", "\n", - "\n", "server_nic = NIC(ip_address=\"130.1.1.2\", gateway=\"130.1.1.255\", subnet_mask=\"255.255.255.0\")\n", "my_server.connect_nic(server_nic)\n", "\n", - "\n", - "net.connect(pc_nic, my_swtich.switch_ports[1])\n", - "net.connect(server_nic, my_swtich.switch_ports[2])\n" + "net.connect(pc_nic, my_switch.network_interface[1])\n", + "net.connect(server_nic, my_switch.network_interface[2])\n" ] }, { @@ -124,7 +124,8 @@ "outputs": [], "source": [ "from primaite.simulator.file_system.file_type import FileType\n", - "from primaite.simulator.file_system.file_system import File" + "from primaite.simulator.file_system.file_system import File\n", + "from primaite.simulator.system.core.sys_log import SysLog" ] }, { @@ -134,7 +135,7 @@ "outputs": [], "source": [ "my_pc_downloads_folder = my_pc.file_system.create_folder(\"downloads\")\n", - "my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",file_type=FileType.ZIP))" + "my_pc_downloads_folder.add_file(File(name=\"firefox_installer.zip\",folder_id=\"Test\", folder_name=\"downloads\" ,file_type=FileType.ZIP, sys_log=SysLog(hostname=\"Test\")))" ] }, { @@ -160,9 +161,12 @@ "metadata": {}, "outputs": [], "source": [ + "from pathlib import Path\n", "from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n", "from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", + "from primaite.simulator.file_system.file_system import FileSystem\n", "\n", "# no applications exist yet so we will create our own.\n", "class MSPaint(Application):\n", @@ -176,7 +180,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, ports={Port.HTTP}, operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual')" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port.HTTP, protocol = IPProtocol.NONE,operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { @@ -257,7 +261,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 1da58409..7f4cf3b1 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -238,7 +238,7 @@ "id": "20", "metadata": {}, "source": [ - "Calling `switch.show()` displays the Switch orts on the Switch." + "Calling `switch.show()` displays the Switch ports on the Switch." ] }, { @@ -256,11 +256,9 @@ { "cell_type": "markdown", "id": "22", - "metadata": { - "tags": [] - }, + "metadata": {}, "source": [ - "Calling `switch.arp.show()` displays the Switch ARP Cache." + "Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=`." ] }, { @@ -271,33 +269,13 @@ "tags": [] }, "outputs": [], - "source": [ - "network.get_node_by_hostname(\"switch_1\").arp.show()" - ] - }, - { - "cell_type": "markdown", - "id": "24", - "metadata": {}, - "source": [ - "Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": { - "tags": [] - }, - "outputs": [], "source": [ "network.get_node_by_hostname(\"switch_1\").sys_log.show()" ] }, { "cell_type": "markdown", - "id": "26", + "id": "24", "metadata": {}, "source": [ "### Computer/Server Nodes\n", @@ -307,7 +285,7 @@ }, { "cell_type": "markdown", - "id": "27", + "id": "25", "metadata": { "tags": [] }, @@ -315,6 +293,26 @@ "Calling `computer.show()` displays the NICs on the Computer/Server." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "network.get_node_by_hostname(\"security_suite\").show()" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "Calling `computer.arp.show()` displays the Computer/Server ARP Cache." + ] + }, { "cell_type": "code", "execution_count": null, @@ -324,7 +322,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"security_suite\").show()" + "network.get_node_by_hostname(\"security_suite\").arp.show()" ] }, { @@ -332,7 +330,7 @@ "id": "29", "metadata": {}, "source": [ - "Calling `computer.arp.show()` displays the Computer/Server ARP Cache." + "Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=`." ] }, { @@ -344,7 +342,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"security_suite\").arp.show()" + "network.get_node_by_hostname(\"security_suite\").sys_log.show()" ] }, { @@ -352,7 +350,9 @@ "id": "31", "metadata": {}, "source": [ - "Calling `switch.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=`." + "## Basic Network Comms Check\n", + "\n", + "We can perform a good old ping to check that Nodes are able to communicate with each other." ] }, { @@ -364,7 +364,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"security_suite\").sys_log.show()" + "network.show(nodes=False, links=False)" ] }, { @@ -372,9 +372,7 @@ "id": "33", "metadata": {}, "source": [ - "## Basic Network Comms Check\n", - "\n", - "We can perform a good old ping to check that Nodes are able to communicate with each other." + "We'll first ping client_1's default gateway." ] }, { @@ -386,27 +384,27 @@ }, "outputs": [], "source": [ - "network.show(nodes=False, links=False)" - ] - }, - { - "cell_type": "markdown", - "id": "35", - "metadata": {}, - "source": [ - "We'll first ping client_1's default gateway." + "network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "36", + "id": "35", "metadata": { "tags": [] }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")" + "network.get_node_by_hostname(\"client_1\").sys_log.show(15)" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, + "source": [ + "Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)." ] }, { @@ -418,7 +416,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_1\").sys_log.show(15)" + "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")" ] }, { @@ -426,7 +424,7 @@ "id": "38", "metadata": {}, "source": [ - "Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)." + "And finally, we'll ping the web server." ] }, { @@ -438,7 +436,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")" + "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")" ] }, { @@ -446,7 +444,7 @@ "id": "40", "metadata": {}, "source": [ - "And finally, we'll ping the web server." + "To confirm that the ping was received and processed by the web_server, we can view the sys log" ] }, { @@ -458,7 +456,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")" + "network.get_node_by_hostname(\"web_server\").sys_log.show()" ] }, { @@ -466,29 +464,29 @@ "id": "42", "metadata": {}, "source": [ - "To confirm that the ping was received and processed by the web_server, we can view the sys log" + "## Advanced Network Usage\n", + "\n", + "We can now use the Network to perform some more advanced things." + ] + }, + { + "cell_type": "markdown", + "id": "43", + "metadata": {}, + "source": [ + "Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..." ] }, { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "44", "metadata": { "tags": [] }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"web_server\").sys_log.show()" - ] - }, - { - "cell_type": "markdown", - "id": "44", - "metadata": {}, - "source": [ - "## Advanced Network Usage\n", - "\n", - "We can now use the Network to perform some more advaced things." + "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")" ] }, { @@ -496,7 +494,7 @@ "id": "45", "metadata": {}, "source": [ - "Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..." + "If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:" ] }, { @@ -508,7 +506,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")" + "network.get_node_by_hostname(\"client_2\").sys_log.show()" ] }, { @@ -516,7 +514,7 @@ "id": "47", "metadata": {}, "source": [ - "If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:" + "Now we'll add an ACL to block ICMP from 192.168.10.22" ] }, { @@ -527,30 +525,10 @@ "tags": [] }, "outputs": [], - "source": [ - "network.get_node_by_hostname(\"client_2\").sys_log.show()" - ] - }, - { - "cell_type": "markdown", - "id": "49", - "metadata": {}, - "source": [ - "Now we'll add an ACL to block ICMP from 192.168.10.22" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "50", - "metadata": { - "tags": [] - }, - "outputs": [], "source": [ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", - "from primaite.simulator.network.hardware.nodes.router import ACLAction\n", + "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", " protocol=IPProtocol.ICMP,\n", @@ -562,7 +540,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51", + "id": "49", "metadata": { "tags": [] }, @@ -573,12 +551,32 @@ }, { "cell_type": "markdown", - "id": "52", + "id": "50", "metadata": {}, "source": [ "Now we attempt (and fail) to ping the web server" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")" + ] + }, + { + "cell_type": "markdown", + "id": "52", + "metadata": {}, + "source": [ + "We can check that the ping was actually sent by client_2 by viewing the sys log" + ] + }, { "cell_type": "code", "execution_count": null, @@ -588,7 +586,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")" + "network.get_node_by_hostname(\"client_2\").sys_log.show()" ] }, { @@ -596,7 +594,7 @@ "id": "54", "metadata": {}, "source": [ - "We can check that the ping was actually sent by client_2 by viewing the sys log" + "We can check the router sys log to see why the traffic was blocked" ] }, { @@ -608,7 +606,7 @@ }, "outputs": [], "source": [ - "network.get_node_by_hostname(\"client_2\").sys_log.show()" + "network.get_node_by_hostname(\"router_1\").sys_log.show()" ] }, { @@ -616,7 +614,7 @@ "id": "56", "metadata": {}, "source": [ - "We can check the router sys log to see why the traffic was blocked" + "Now a final check to ensure that client_1 can still ping the web_server." ] }, { @@ -627,26 +625,6 @@ "tags": [] }, "outputs": [], - "source": [ - "network.get_node_by_hostname(\"router_1\").sys_log.show()" - ] - }, - { - "cell_type": "markdown", - "id": "58", - "metadata": {}, - "source": [ - "Now a final check to ensure that client_1 can still ping the web_server." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "59", - "metadata": { - "tags": [] - }, - "outputs": [], "source": [ "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")" ] @@ -654,7 +632,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60", + "id": "58", "metadata": { "tags": [] }, diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 31378689..8928e8ef 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -316,6 +316,16 @@ class HostNode(Node): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=ip_address, subnet_mask=subnet_mask)) + @property + def arp(self) -> Optional[ARP]: + """ + Return the ARP Cache of the HostNode. + + :return: ARP Cache for given HostNode + :rtype: Optional[ARP] + """ + return self.software_manager.software.get("ARP") + def _install_system_software(self): """ Installs the system software and network services typically found on an operating system. diff --git a/src/primaite/simulator/network/hardware/nodes/network/network_node.py b/src/primaite/simulator/network/hardware/nodes/network/network_node.py index ebdb6ed8..0474ca08 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/network_node.py +++ b/src/primaite/simulator/network/hardware/nodes/network/network_node.py @@ -1,7 +1,9 @@ from abc import abstractmethod +from typing import Optional from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.network.transmission.data_link_layer import Frame +from primaite.simulator.system.services.arp.arp import ARP class NetworkNode(Node): @@ -28,3 +30,13 @@ class NetworkNode(Node): :type from_network_interface: NetworkInterface """ pass + + @property + def arp(self) -> Optional[ARP]: + """ + Return the ARP Cache of the NetworkNode. + + :return: ARP Cache for given NetworkNode + :rtype: Optional[ARP] + """ + return self.software_manager.software.get("ARP") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 5d041fd1..ce188838 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1237,8 +1237,7 @@ class Router(NetworkNode): icmp: RouterICMP = self.software_manager.icmp # noqa icmp.router = self self.software_manager.install(RouterARP) - arp: RouterARP = self.software_manager.arp # noqa - arp.router = self + self.arp.router = self def _set_default_acl(self): """