Fix tests
This commit is contained in:
@@ -19,7 +19,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
|
||||
...
|
||||
- ref: green_agent_example
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: probabilistic_agent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
@@ -57,11 +57,11 @@ Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive
|
||||
|
||||
``type``
|
||||
--------
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour.
|
||||
Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``probabilistic_agent`` generate their own behaviour.
|
||||
|
||||
Available agent types:
|
||||
|
||||
- ``GreenWebBrowsingAgent``
|
||||
- ``probabilistic_agent``
|
||||
- ``ProxyAgent``
|
||||
- ``RedDatabaseCorruptingAgent``
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ agents:
|
||||
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: probabilistic_agent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
|
||||
@@ -40,7 +39,7 @@ class SessionMode(Enum):
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self, game: PrimaiteGame):
|
||||
def __init__(self, game_cfg: Dict):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
@@ -57,7 +56,7 @@ class PrimaiteSession:
|
||||
self.io_manager: Optional["SessionIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game: PrimaiteGame = game
|
||||
self.game_cfg: Dict = game_cfg
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
@@ -93,9 +92,7 @@ class PrimaiteSession:
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
io_manager = SessionIO(SessionIOSettings(**io_settings))
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
sess = cls(game=game)
|
||||
sess = cls(game_cfg=cfg)
|
||||
sess.io_manager = io_manager
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ game:
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
type: probabilistic_agent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
|
||||
@@ -21,15 +21,15 @@ class TestPrimaiteSession:
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.game.simulation
|
||||
assert len(session.game.agents) == 3
|
||||
assert len(session.game.rl_agents) == 1
|
||||
assert session.env.game.simulation
|
||||
assert len(session.env.game.agents) == 3
|
||||
assert len(session.env.game.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.game.simulation.network
|
||||
assert len(session.game.simulation.network.nodes) == 10
|
||||
assert session.env.game.simulation.network
|
||||
assert len(session.env.game.simulation.network.nodes) == 10
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
|
||||
Reference in New Issue
Block a user