Fix tests

This commit is contained in:
Marek Wolan
2024-03-01 15:14:00 +00:00
parent 9ff8adab1a
commit 10a4053887
5 changed files with 13 additions and 16 deletions

View File

@@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
...
- ref: green_agent_example
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:
@@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive
``type``
--------
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour.
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour.
Available agent types:
- ``GreenWebBrowsingAgent``
- ``probabilistic_agent``
- ``ProxyAgent``
- ``RedDatabaseCorruptingAgent``

View File

@@ -64,7 +64,7 @@ agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
@@ -40,7 +39,7 @@ class SessionMode(Enum):
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game: PrimaiteGame):
def __init__(self, game_cfg: Dict):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
@@ -57,7 +56,7 @@ class PrimaiteSession:
self.io_manager: Optional["SessionIO"] = None
"""IO manager for the session."""
self.game: PrimaiteGame = game
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
def start_session(self) -> None:
@@ -93,9 +92,7 @@ class PrimaiteSession:
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])

View File

@@ -65,7 +65,7 @@ game:
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
type: probabilistic_agent
observation_space:
type: UC2GreenObservation
action_space:

View File

@@ -21,15 +21,15 @@ class TestPrimaiteSession:
raise AssertionError
assert session is not None
assert session.game.simulation
assert len(session.game.agents) == 3
assert len(session.game.rl_agents) == 1
assert session.env.game.simulation
assert len(session.env.game.agents) == 3
assert len(session.env.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.game.simulation.network
assert len(session.game.simulation.network.nodes) == 10
assert session.env.game.simulation.network
assert len(session.env.game.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):