#2034 - Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation. Added a test that uninstalls a service before performing a reset then checks that the service reappears.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from pathlib import Path
|
||||
@@ -140,6 +141,9 @@ class PrimaiteSession:
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self._simulation_initial_state = deepcopy(self.simulation)
|
||||
"""The Simulation original state (deepcopy of the original Simulation)."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
|
||||
@@ -277,7 +281,7 @@ class PrimaiteSession:
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
self.simulation = self._simulation_initial_state
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
@@ -511,4 +515,6 @@ class PrimaiteSession:
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
|
||||
|
||||
return sess
|
||||
|
||||
@@ -100,8 +100,16 @@ class SoftwareManager:
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import datetime
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import nodeenv
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import TempPrimaiteSession
|
||||
|
||||
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
|
||||
|
||||
class TestPrimaiteSession:
|
||||
@@ -66,3 +67,17 @@ class TestPrimaiteSession:
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_session_sim_reset(self, temp_primaite_session):
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
client_1 = session.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.uninstall("DataManipulationBot")
|
||||
|
||||
assert "DataManipulationBot" not in client_1.software_manager.software
|
||||
|
||||
session.reset()
|
||||
client_1 = session.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
assert "DataManipulationBot" in client_1.software_manager.software
|
||||
|
||||
Reference in New Issue
Block a user