diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index fa17b94b..e96b9a42 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,4 +1,5 @@ """PrimAITE game - Encapsulates the simulation and agents.""" +from copy import deepcopy from ipaddress import IPv4Address from typing import Dict, List @@ -56,6 +57,9 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" + self._simulation_initial_state = deepcopy(self.simulation) + """The Simulation original state (deepcopy of the original Simulation).""" + self.agents: List[AbstractAgent] = [] """List of agents.""" @@ -152,8 +156,8 @@ class PrimaiteGame: """Reset the game, this will reset the simulation.""" self.episode_counter += 1 self.step_counter = 0 - _LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(self.episode_counter) + _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") + self.simulation = deepcopy(self._simulation_initial_state) def close(self) -> None: """Close the game, this will close the simulation.""" @@ -287,7 +291,7 @@ class PrimaiteGame: node_ref ] = ( new_node.uuid - ) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object + ) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object # 2. create links between nodes for link_cfg in links_cfg: @@ -371,4 +375,6 @@ class PrimaiteGame: else: print("agent type not found") + game._simulation_initial_state = deepcopy(game.simulation) # noqa + return game diff --git a/src/primaite/simulator/system/core/software_manager.py b/src/primaite/simulator/system/core/software_manager.py index 8b8fe599..21a121c1 100644 --- a/src/primaite/simulator/system/core/software_manager.py +++ b/src/primaite/simulator/system/core/software_manager.py @@ -100,8 +100,16 @@ class SoftwareManager: self.node.uninstall_application(software) elif isinstance(software, Service): self.node.uninstall_service(software) + for key, value in self.port_protocol_mapping.items(): + if value.name == software_name: + self.port_protocol_mapping.pop(key) + break + for key, value in self._software_class_to_name_map.items(): + if value == software_name: + self._software_class_to_name_map.pop(key) + break del software - self.sys_log.info(f"Deleted {software_name}") + self.sys_log.info(f"Uninstalled {software_name}") return self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed") diff --git a/tests/conftest.py b/tests/conftest.py index 24202350..c11170b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,7 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK -import datetime -import shutil -import tempfile -from datetime import datetime from pathlib import Path from typing import Any, Dict, Union -import nodeenv import pytest import yaml diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 25b8998b..17d8a4d1 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -1,13 +1,14 @@ import pydantic import pytest +from tests import TEST_ASSETS_ROOT from tests.conftest import TempPrimaiteSession -CFG_PATH = "tests/assets/configs/test_primaite_session.yaml" -TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml" -EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml" -MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml" -MULTI_AGENT_PATH = "tests/assets/configs/multi_agent_session.yaml" +CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml" +TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml" +EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml" +MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml" +MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" class TestPrimaiteSession: @@ -73,3 +74,17 @@ class TestPrimaiteSession: def test_error_thrown_on_bad_configuration(self): with pytest.raises(pydantic.ValidationError): session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH) + + @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) + def test_session_sim_reset(self, temp_primaite_session): + with temp_primaite_session as session: + session: TempPrimaiteSession + client_1 = session.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.uninstall("DataManipulationBot") + + assert "DataManipulationBot" not in client_1.software_manager.software + + session.reset() + client_1 = session.simulation.network.get_node_by_hostname("client_1") + + assert "DataManipulationBot" in client_1.software_manager.software