Merge 'origin/dev' into feature/1971-ray-agents-2
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from copy import deepcopy
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -56,6 +57,9 @@ class PrimaiteGame:
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self._simulation_initial_state = deepcopy(self.simulation)
|
||||
"""The Simulation original state (deepcopy of the original Simulation)."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
|
||||
@@ -152,8 +156,8 @@ class PrimaiteGame:
|
||||
"""Reset the game, this will reset the simulation."""
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
|
||||
self.simulation = deepcopy(self._simulation_initial_state)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the game, this will close the simulation."""
|
||||
@@ -287,7 +291,7 @@ class PrimaiteGame:
|
||||
node_ref
|
||||
] = (
|
||||
new_node.uuid
|
||||
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
|
||||
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
|
||||
|
||||
# 2. create links between nodes
|
||||
for link_cfg in links_cfg:
|
||||
@@ -371,4 +375,6 @@ class PrimaiteGame:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
game._simulation_initial_state = deepcopy(game.simulation) # noqa
|
||||
|
||||
return game
|
||||
|
||||
@@ -100,8 +100,16 @@ class SoftwareManager:
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
for key, value in self.port_protocol_mapping.items():
|
||||
if value.name == software_name:
|
||||
self.port_protocol_mapping.pop(key)
|
||||
break
|
||||
for key, value in self._software_class_to_name_map.items():
|
||||
if value == software_name:
|
||||
self._software_class_to_name_map.pop(key)
|
||||
break
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
self.sys_log.info(f"Uninstalled {software_name}")
|
||||
return
|
||||
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
|
||||
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import datetime
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import nodeenv
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import TempPrimaiteSession
|
||||
|
||||
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = "tests/assets/configs/multi_agent_session.yaml"
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
class TestPrimaiteSession:
|
||||
@@ -73,3 +74,17 @@ class TestPrimaiteSession:
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_session_sim_reset(self, temp_primaite_session):
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
client_1 = session.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.uninstall("DataManipulationBot")
|
||||
|
||||
assert "DataManipulationBot" not in client_1.software_manager.software
|
||||
|
||||
session.reset()
|
||||
client_1 = session.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
assert "DataManipulationBot" in client_1.software_manager.software
|
||||
|
||||
Reference in New Issue
Block a user