diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index bd7ed2cd..72ad01e7 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -160,16 +160,6 @@ class PrimaiteGame: return True return False - def reset(self) -> None: # TODO: deprecated - remove me - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index e21ce9eb..b9188bf0 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -160,8 +160,12 @@ class SimComponent(BaseModel): self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. + + For instance, some components may require for the entire network to exist before some configuration can be set. + """ pass def _init_request_manager(self) -> RequestManager: diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index 48205bbd..080a1004 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -45,12 +45,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode) for node in self.nodes.values(): node.power_on() @@ -171,7 +171,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode() def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 67ac42c8..e2a90db1 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -123,9 +123,9 @@ class NIC(SimComponent): _LOGGER.error(msg) raise ValueError(msg) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) if episode and self.pcap: self.pcap.current_episode = episode self.pcap.setup_logger() @@ -1011,19 +1011,19 @@ class Node(SimComponent): self.session_manager.software_manager = self.software_manager self._install_system_software() - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode) # Reset all Nics for nic in self.nics.values(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode) if episode and self.sys_log: self.sys_log.current_episode = episode diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index aa154ad9..887bc9be 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -743,16 +743,16 @@ class Router(Node): self.arp.nics = self.nics self.icmp.arp = self.arp - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" self.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) + self.acl.setup_for_episode(episode) + self.route_table.setup_for_episode(episode) for i, nic in self.ethernet_ports.items(): - nic.reset_component_for_episode(episode) + nic.setup_for_episode(episode) self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager() diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 18ed894c..bb6132a8 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -21,9 +21,9 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode) def _init_request_manager(self) -> RequestManager: rm = super()._init_request_manager()