From efeaa4c1cc13edd9c128e4dc47aa52d5ec586cde Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 23 Nov 2023 15:31:06 +0000 Subject: [PATCH] #2034 - Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation. Added a test that uninstalls a service before performing a reset then checks that the service reappears. --- src/primaite/game/session.py | 8 ++++++- .../simulator/system/core/software_manager.py | 10 +++++++- tests/conftest.py | 5 ---- .../test_primaite_session.py | 23 +++++++++++++++---- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index ad0537e8..572dbecb 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -1,4 +1,5 @@ """PrimAITE session - the main entry point to training agents on PrimAITE.""" +from copy import deepcopy from enum import Enum from ipaddress import IPv4Address from pathlib import Path @@ -140,6 +141,9 @@ class PrimaiteSession: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + self._simulation_initial_state = deepcopy(self.simulation) + """The Simulation original state (deepcopy of the original Simulation).""" + self.agents: List[AbstractAgent] = [] """List of agents.""" @@ -277,7 +281,7 @@ class PrimaiteSession: self.episode_counter += 1 self.step_counter = 0 _LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(self.episode_counter) + self.simulation = self._simulation_initial_state def close(self) -> None: """Close the session, this will stop the env and close the simulation.""" @@ -511,4 +515,6 @@ class PrimaiteSession: if agent_load_path: sess.policy.load(Path(agent_load_path)) + sess._simulation_initial_state = deepcopy(sess.simulation) # noqa + return sess diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 8b8fe599..21a121c1 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -100,8 +100,16 @@ class SoftwareManager: self.node.uninstall_application(software) elif isinstance(software, Service): self.node.uninstall_service(software) + for key, value in self.port_protocol_mapping.items(): + if value.name == software_name: + self.port_protocol_mapping.pop(key) + break + for key, value in self._software_class_to_name_map.items(): + if value == software_name: + self._software_class_to_name_map.pop(key) + break del software - self.sys_log.info(f"Deleted {software_name}") + self.sys_log.info(f"Uninstalled {software_name}") return self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") diff --git a/tests/conftest.py b/tests/conftest.py index 6a65b12f..419a6128 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,7 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -import datetime -import shutil -import tempfile -from datetime import datetime from pathlib import Path from typing import Any, Dict, Union -import nodeenv import pytest import yaml diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index b6122bad..6839a190 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -1,12 +1,13 @@ import pydantic import pytest +from tests import TEST_ASSETS_ROOT from tests.conftest import TempPrimaiteSession -CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" -TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" -EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" -MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml" +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml" class TestPrimaiteSession: @@ -66,3 +67,17 @@ class TestPrimaiteSession: def test_error_thrown_on_bad_configuration(self): with pytest.raises(pydantic.ValidationError): session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH) + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_session_sim_reset(self, temp_primaite_session): + with temp_primaite_session as session: + session: TempPrimaiteSession + client_1 = session.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.uninstall("DataManipulationBot") + + assert "DataManipulationBot" not in client_1.software_manager.software + + session.reset() + client_1 = session.simulation.network.get_node_by_hostname("client_1") + + assert "DataManipulationBot" in client_1.software_manager.software