diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f2da23..227cf729 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config. - Added support for SQL INSERT command. - Added ability to log each agent's action choices in each step to a JSON file. +- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present. ### Bug Fixes diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst index ab51d7fd..b19574f4 100644 --- a/docs/source/configuration/simulation.rst +++ b/docs/source/configuration/simulation.rst @@ -64,10 +64,12 @@ this results in: endpoint_a_port: 1 # port 1 on computer_1 endpoint_b_hostname: switch endpoint_b_port: 1 # port 1 on switch + bandwidth: 100 - endpoint_a_hostname: computer_2 endpoint_a_port: 1 # port 1 on computer_2 endpoint_b_hostname: switch endpoint_b_port: 2 # port 2 on switch + bandwidth: 100 ``ref`` ^^^^^^^ @@ -95,3 +97,7 @@ The ``hostname`` of the node which must be connected. The port on ``endpoint_b_hostname`` which is to be connected to ``endpoint_a_port``. This accepts an integer value e.g. if port 1 is to be connected, the configuration should be ``endpoint_b_port: 1`` + +``bandwidth`` + +This is an integer value specifying the allowed bandwidth across the connection. Units are in Mbps. diff --git a/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml b/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml index 8b97c6df..09e85d03 100644 --- a/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml +++ b/src/primaite/config/_package_data/multi_lan_internet_network_example.yaml @@ -368,6 +368,7 @@ simulation: endpoint_a_port: 1 endpoint_b_hostname: switch_1 endpoint_b_port: 1 + bandwidth: 200 - endpoint_a_hostname: pc_2 endpoint_a_port: 1 endpoint_b_hostname: switch_1 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ab68ea2d..ea5b3831 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -409,6 +409,7 @@ class PrimaiteGame: for link_cfg in links_cfg: node_a = net.get_node_by_hostname(link_cfg["endpoint_a_hostname"]) node_b = net.get_node_by_hostname(link_cfg["endpoint_b_hostname"]) + bandwidth = link_cfg.get("bandwidth", 100) # default value if not configured if isinstance(node_a, Switch): endpoint_a = node_a.network_interface[link_cfg["endpoint_a_port"]] @@ -418,7 +419,7 @@ class PrimaiteGame: endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]] else: endpoint_b = node_b.network_interface[link_cfg["endpoint_b_port"]] - net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b) + net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth) # 3. create agents agents_cfg = cfg.get("agents", []) diff --git a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index 4038c0c2..1fb66405 100644 --- a/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -97,7 +97,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/session/episode_schedule.py b/src/primaite/session/episode_schedule.py index fa010d27..c009fa09 100644 --- a/src/primaite/session/episode_schedule.py +++ b/src/primaite/session/episode_schedule.py @@ -57,7 +57,7 @@ class EpisodeListScheduler(EpisodeScheduler): if episode_num >= len(self.schedule): if not self._exceeded_episode_list: self._exceeded_episode_list = True - _LOGGER.warn( + _LOGGER.warning( f"Running episode {episode_num} but the schedule only defines " f"{len(self.schedule)} episodes. Looping back to the beginning" ) diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 17308c97..91ea3c71 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -309,7 +309,9 @@ class Network(SimComponent): self._node_request_manager.remove_request(name=node.hostname) _LOGGER.info(f"Removed node {node.hostname} from network {self.uuid}") - def connect(self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, **kwargs) -> Optional[Link]: + def connect( + self, endpoint_a: WiredNetworkInterface, endpoint_b: WiredNetworkInterface, bandwidth: int = 100, **kwargs + ) -> Optional[Link]: """ Connect two endpoints on the network by creating a link between their NICs/SwitchPorts. @@ -319,6 +321,8 @@ class Network(SimComponent): :type endpoint_a: WiredNetworkInterface :param endpoint_b: The second endpoint to connect. :type endpoint_b: WiredNetworkInterface + :param bandwidth: bandwidth of new link, default of 100mbps + :type bandwidth: int :raises RuntimeError: If any validation or runtime checks fail. """ node_a: Node = endpoint_a.parent @@ -330,7 +334,7 @@ class Network(SimComponent): if node_a is node_b: _LOGGER.warning(f"Cannot link endpoint {endpoint_a} to {endpoint_b} because they belong to the same node.") return - link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, **kwargs) + link = Link(endpoint_a=endpoint_a, endpoint_b=endpoint_b, bandwidth=bandwidth, **kwargs) self.links[link.uuid] = link self._link_id_map[len(self.links)] = link self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname) diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 8bda626a..f4475bec 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -50,6 +50,7 @@ def create_office_lan( num_pcs: int, network: Optional[Network] = None, include_router: bool = True, + bandwidth: int = 100, ) -> Network: """ Creates a 2-Tier or 3-Tier office local area network (LAN). @@ -109,9 +110,11 @@ def create_office_lan( switch.power_on() network.add_node(switch) if num_of_switches > 1: - network.connect(core_switch.network_interface[core_switch_port], switch.network_interface[24]) + network.connect( + core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth + ) else: - network.connect(router.network_interface[1], switch.network_interface[24]) + network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) # Add PCs to the LAN and connect them to switches for i in range(1, num_pcs + 1): @@ -125,9 +128,11 @@ def create_office_lan( # Connect the new switch to the router or core switch if num_of_switches > 1: core_switch_port += 1 - network.connect(core_switch.network_interface[core_switch_port], switch.network_interface[24]) + network.connect( + core_switch.network_interface[core_switch_port], switch.network_interface[24], bandwidth=bandwidth + ) else: - network.connect(router.network_interface[1], switch.network_interface[24]) + network.connect(router.network_interface[1], switch.network_interface[24], bandwidth=bandwidth) # Create and add a PC to the network pc = Computer( @@ -142,7 +147,7 @@ def create_office_lan( # Connect the PC to the switch switch_port += 1 - network.connect(switch.network_interface[switch_port], pc.network_interface[1]) + network.connect(switch.network_interface[switch_port], pc.network_interface[1], bandwidth=bandwidth) switch.network_interface[switch_port].enable() return network diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index b4d32dc4..a515ce58 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -547,15 +547,15 @@ class Link(SimComponent): :param endpoint_a: The first NIC or SwitchPort connected to the Link. :param endpoint_b: The second NIC or SwitchPort connected to the Link. - :param bandwidth: The bandwidth of the Link in Mbps (default is 100 Mbps). + :param bandwidth: The bandwidth of the Link in Mbps. """ endpoint_a: WiredNetworkInterface "The first WiredNetworkInterface connected to the Link." endpoint_b: WiredNetworkInterface "The second WiredNetworkInterface connected to the Link." - bandwidth: float = 100.0 - "The bandwidth of the Link in Mbps (default is 100 Mbps)." + bandwidth: float + "The bandwidth of the Link in Mbps." current_load: float = 0.0 "The current load on the link in Mbps." diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 37505f6e..0cbaefdb 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -234,7 +234,9 @@ simulation: endpoint_a_port: 1 endpoint_b_hostname: client_1 endpoint_b_port: 1 + bandwidth: 200 - endpoint_a_hostname: switch_1 endpoint_a_port: 2 endpoint_b_hostname: client_2 endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py index 174bd0c0..b71c0c9c 100644 --- a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -43,3 +43,6 @@ def test_basic_config(): # client 3 should not be online client_3: Computer = network.get_node_by_hostname("client_3") assert client_3.operating_state == NodeOperatingState.OFF + + for link in network.links: + assert network.links[link].bandwidth == 200 diff --git a/tests/integration_tests/system/test_web_client_server_and_database.py b/tests/integration_tests/system/test_web_client_server_and_database.py index 56f07634..0d1bb584 100644 --- a/tests/integration_tests/system/test_web_client_server_and_database.py +++ b/tests/integration_tests/system/test_web_client_server_and_database.py @@ -49,12 +49,12 @@ def web_client_web_server_database(example_network) -> Tuple[Network, Computer, db_server_nic = db_server.network_interfaces[next(iter(db_server.network_interfaces))] # Connect Computer and Server - link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic) + link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic, bandwidth=100) # Should be linked assert link_computer_server.is_up # Connect database server and web server - link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic) + link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic, bandwidth=100) # Should be linked assert link_computer_server.is_up assert link_server_db.is_up