From 10a40538876930afa371bac5c77691626917472b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 1 Mar 2024 15:14:00 +0000 Subject: [PATCH] Fix tests --- docs/source/configuration/agents.rst | 6 +++--- .../_package_data/example_config_2_rl_agents.yaml | 2 +- src/primaite/session/session.py | 9 +++------ tests/assets/configs/dmz_network.yaml | 2 +- tests/e2e_integration_tests/test_primaite_session.py | 10 +++++----- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index f32843b1..ac67c365 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo ... - ref: green_agent_example team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: @@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive ``type`` -------- -Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour. +Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour. Available agent types: -- ``GreenWebBrowsingAgent`` +- ``probabilistic_agent`` - ``ProxyAgent`` - ``RedDatabaseCorruptingAgent`` diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index d6d3f044..b6b07afa 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -64,7 +64,7 @@ agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index b8f80e95..d244f6b0 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv from primaite.session.io import SessionIO, SessionIOSettings @@ -40,7 +39,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self, game: PrimaiteGame): + def __init__(self, game_cfg: Dict): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -57,7 +56,7 @@ class PrimaiteSession: self.io_manager: Optional["SessionIO"] = None """IO manager for the session.""" - self.game: PrimaiteGame = game + self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" def start_session(self) -> None: @@ -93,9 +92,7 @@ class PrimaiteSession: io_settings = cfg.get("io_settings", {}) io_manager = SessionIO(SessionIOSettings(**io_settings)) - game = PrimaiteGame.from_config(cfg) - - sess = cls(game=game) + sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 880735d9..56a68410 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -65,7 +65,7 @@ game: agents: - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: probabilistic_agent observation_space: type: UC2GreenObservation action_space: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 7785e4ae..da13dcd8 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -21,15 +21,15 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.game.simulation - assert len(session.game.agents) == 3 - assert len(session.game.rl_agents) == 1 + assert session.env.game.simulation + assert len(session.env.game.agents) == 3 + assert len(session.env.game.rl_agents) == 1 assert session.policy assert session.env - assert session.game.simulation.network - assert len(session.game.simulation.network.nodes) == 10 + assert session.env.game.simulation.network + assert len(session.env.game.simulation.network.nodes) == 10 @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session):