#2034 - Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation. Added a test that uninstalls a service before performing a reset then checks that the service reappears.

This commit is contained in:
Chris McCarthy
2023-11-23 15:31:06 +00:00
parent 4ee29efd1f
commit efeaa4c1cc
4 changed files with 35 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
@@ -140,6 +141,9 @@ class PrimaiteSession:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -277,7 +281,7 @@ class PrimaiteSession:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
self.simulation = self._simulation_initial_state
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
@@ -511,4 +515,6 @@ class PrimaiteSession:
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -1,12 +1,7 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
import nodeenv
import pytest
import yaml

View File

@@ -1,12 +1,13 @@
import pydantic
import pytest
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
class TestPrimaiteSession:
@@ -66,3 +67,17 @@ class TestPrimaiteSession:
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.reset()
client_1 = session.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software