Merge 'origin/dev' into feature/1971-ray-agents-2

This commit is contained in:
Marek Wolan
2023-11-24 09:50:37 +00:00
4 changed files with 38 additions and 14 deletions

View File

@@ -1,4 +1,5 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from ipaddress import IPv4Address
from typing import Dict, List
@@ -56,6 +57,9 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -152,8 +156,8 @@ class PrimaiteGame:
"""Reset the game, this will reset the simulation."""
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite game, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -287,7 +291,7 @@ class PrimaiteGame:
node_ref
] = (
new_node.uuid
) # TODO: fix incosistency with service and link. Node gets added by uuid, but service by object
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
# 2. create links between nodes
for link_cfg in links_cfg:
@@ -371,4 +375,6 @@ class PrimaiteGame:
else:
print("agent type not found")
game._simulation_initial_state = deepcopy(game.simulation) # noqa
return game

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -1,12 +1,7 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
import nodeenv
import pytest
import yaml

View File

@@ -1,13 +1,14 @@
import pydantic
import pytest
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = "tests/assets/configs/multi_agent_session.yaml"
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
class TestPrimaiteSession:
@@ -73,3 +74,17 @@ class TestPrimaiteSession:
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.reset()
client_1 = session.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software