Merged PR 216: #2034 - Implemented the Simulation reset functionality

## Summary
Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation and assigning it to a `_simulation_initial_state` variable. Then when reset called, the sessions `simulation` variable is assigned the value of the `_simulation_initial_state` variable.

## Test process
Added a test that uninstalls a service before performing a reset then checks that the service reappears.

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [X] attended to any **TO-DOs** left in the code

#2034 - Implemented the Simulation reset functionality by doing a deepcopy of the Simulation object inside the PrimaiteSession upon instantiation. Added a test that uninstalls a service before performing a reset then checks that the service reappears.

Related work items: #2034
This commit is contained in:
Christopher McCarthy
2023-11-23 22:33:12 +00:00
4 changed files with 35 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from pathlib import Path
@@ -140,6 +141,9 @@ class PrimaiteSession:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -277,7 +281,7 @@ class PrimaiteSession:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the session, this will stop the env and close the simulation."""
@@ -511,4 +515,6 @@ class PrimaiteSession:
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -1,12 +1,7 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
import nodeenv
import pytest
import yaml

View File

@@ -1,12 +1,13 @@
import pydantic
import pytest
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
class TestPrimaiteSession:
@@ -66,3 +67,17 @@ class TestPrimaiteSession:
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.reset()
client_1 = session.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software