# File: PrimAITE/tests/e2e_integration_tests/test_environment.py
# Snapshot: 2024-05-29 14:54:45 +01:00 (original listing: 93 lines, 3.8 KiB, Python)

import pydantic
import pytest
import yaml
from gymnasium.core import ObsType
from numpy import ndarray
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayMARLEnv
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT
# Session config files used by the tests below, all under the shared
# test-assets "configs" directory.
# Main single-agent session config (used for Gym env creation/stepping).
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
# NOTE(review): TRAINING_ONLY_PATH and EVAL_ONLY_PATH are not referenced by
# the tests in this module as seen here — confirm whether they are still used.
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
# Deliberately invalid config; env construction from it is expected to raise
# pydantic.ValidationError.
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
# Two-agent config ("defender1", "defender2") for the Ray multi-agent env test.
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def _load_config(path):
    """Load and return the YAML session config stored at ``path``.

    ``yaml.safe_load`` is used so a config file cannot construct arbitrary
    Python objects. The file is decoded as UTF-8 explicitly so the result does
    not depend on the platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


class TestPrimaiteEnvironment:
    """End-to-end checks that PrimAITE environments build, reset and step from YAML configs."""

    def test_creating_env(self):
        """Check that environment loads correctly from config and it can be reset."""
        env = PrimaiteGymEnv(env_config=_load_config(CFG_PATH))

        def env_checks():
            # Structural expectations for the test config. Run both before and
            # after env.reset() so the same game/network must exist afterwards.
            assert env is not None
            assert env.game.simulation
            assert len(env.game.agents) == 3
            assert len(env.game.rl_agents) == 1
            assert env.game.simulation.network
            assert len(env.game.simulation.network.nodes) == 12
            wireless = env.game.simulation.network.get_node_by_hostname("router_2")
            assert isinstance(wireless, WirelessRouter)
            printer = env.game.simulation.network.get_node_by_hostname("HP_LaserJet_Pro_4102fdn_printer")
            assert isinstance(printer, Printer)

        env_checks()
        env.reset()
        env_checks()

    def test_step_env(self):
        """Make sure you can go all the way through the session without errors."""
        env = PrimaiteGymEnv(env_config=_load_config(CFG_PATH))

        # Bind outside the assert: an assignment inside `assert` is removed
        # under `python -O`, which would leave num_actions undefined below.
        num_actions = len(env.agent.action_manager.action_map)
        assert num_actions == 54

        # run every action and make sure there's no crash
        for act in range(num_actions):
            env.step(act)

        # An action index one past the end of the action map must be rejected.
        with pytest.raises(KeyError):
            env.step(num_actions)

        # Gymnasium step() order is (obs, reward, terminated, truncated, info).
        obs, rew, term, trunc, info = env.step(0)
        assert isinstance(obs, ndarray)

    def test_multi_agent_env(self):
        """Check that we can run a training session with a multi agent system."""
        env = PrimaiteRayMARLEnv(env_config=_load_config(MULTI_AGENT_PATH))

        assert set(env._agent_ids) == {"defender1", "defender2"}
        assert len(env.agents) == 2
        defender1 = env.agents["defender1"]
        defender2 = env.agents["defender2"]

        # Bound outside the asserts so the values survive `python -O`.
        num_actions_1 = len(defender1.action_manager.action_map)
        num_actions_2 = len(defender2.action_manager.action_map)
        assert num_actions_1 == 54
        assert num_actions_2 == 38

        # ensure we can run all valid actions without error, one agent at a
        # time while the other takes action 0
        for act_1 in range(num_actions_1):
            env.step({"defender1": act_1, "defender2": 0})
        for act_2 in range(num_actions_2):
            env.step({"defender1": 0, "defender2": act_2})

        # ensure we get error when taking an invalid (out-of-range) action
        with pytest.raises(KeyError):
            env.step({"defender1": num_actions_1, "defender2": 0})
        with pytest.raises(KeyError):
            env.step({"defender1": 0, "defender2": num_actions_2})

    def test_error_thrown_on_bad_configuration(self):
        """Make sure we throw an error when the config is bad."""
        cfg = _load_config(MISCONFIGURED_PATH)
        with pytest.raises(pydantic.ValidationError):
            PrimaiteGymEnv(env_config=cfg)