#2034 - Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation. Added a test that uninstalls a service before performing a reset then checks that the service reappears.

This commit is contained in:
Chris McCarthy
2023-11-23 15:31:06 +00:00
parent 4ee29efd1f
commit efeaa4c1cc
4 changed files with 35 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
@@ -140,6 +141,9 @@ class PrimaiteSession:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -277,7 +281,7 @@ class PrimaiteSession:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
self.simulation = self._simulation_initial_state
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
@@ -511,4 +515,6 @@ class PrimaiteSession:
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")