diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 1bf2ea2f..667cf2bf 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -354,6 +354,11 @@ class RouteEntry(SimComponent): kwargs[key] = IPv4Address(kwargs[key]) super().__init__(**kwargs) + def set_original_state(self): + """Sets the original state.""" + vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} + self._original_values = self.model_dump(include=vals_to_include) + def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -385,6 +390,18 @@ class RouteTable(SimComponent): routes: List[RouteEntry] = [] sys_log: SysLog + def set_original_state(self): + """Sets the original state.""" + """Sets the original state.""" + super().set_original_state() + self._original_state["routes_orig"] = self.routes + + def reset_component_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + self.routes.clear() + self.routes = self._original_state["routes_orig"] + super().reset_component_for_episode(episode) + def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -660,13 +677,15 @@ class Router(Node): def set_original_state(self): """Sets the original state.""" self.acl.set_original_state() - vals_to_include = {"num_ports", "route_table"} + self.route_table.set_original_state() + vals_to_include = {"num_ports"} self._original_state = self.model_dump(include=vals_to_include) def reset_component_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() self.acl.reset_component_for_episode(episode) + self.route_table.reset_component_for_episode(episode) for i, nic in self.ethernet_ports.items(): nic.reset_component_for_episode(episode) self.enable_port(i) @@ -765,6 +784,7 @@ class Router(Node): dst_ip_address=dst_ip_address, dst_port=dst_port, ) + if not permitted: at_port = self._get_port_of_nic(from_nic) self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}")