#2588 optional RL deps
This commit is contained in:
@@ -26,7 +26,7 @@ jobs:
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
displayName: 'Install PrimAITE for docs autosummary'
|
||||
|
||||
- script: |
|
||||
|
||||
@@ -82,12 +82,12 @@ stages:
|
||||
|
||||
- script: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]"
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
|
||||
2
.github/workflows/build-sphinx.yml
vendored
2
.github/workflows/build-sphinx.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
- name: Install PrimAITE for docs autosummary
|
||||
run: |
|
||||
set -x
|
||||
python -m pip install -e .[dev]
|
||||
python -m pip install -e .[dev,rl]
|
||||
|
||||
- name: Run build script for Sphinx pages
|
||||
env:
|
||||
|
||||
2
.github/workflows/python-package.yml
vendored
2
.github/workflows/python-package.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
- name: Install PrimAITE
|
||||
run: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
|
||||
- name: Perform PrimAITE Setup
|
||||
run: |
|
||||
|
||||
@@ -43,7 +43,7 @@ cd ~\primaite
|
||||
python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -66,7 +66,7 @@ mkdir ~/primaite
|
||||
cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -105,7 +105,7 @@ source venv/bin/activate
|
||||
#### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
python3 -m pip install -e .[dev,rl]
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
@@ -114,6 +114,9 @@ python3 -m pip install -e .[dev]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
#### Note
|
||||
*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` flag in the pip install command.*
|
||||
|
||||
### Running PrimAITE
|
||||
|
||||
Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`.
|
||||
|
||||
@@ -82,7 +82,7 @@ Install PrimAITE
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install path/to/your/primaite.whl
|
||||
pip install path/to/your/primaite.whl[rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
@@ -133,12 +133,12 @@ of your choice:
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
|
||||
|
||||
|
||||
@@ -36,11 +36,8 @@ dependencies = [
|
||||
"polars==0.18.4",
|
||||
"prettytable==3.8.0",
|
||||
"PyYAML==6.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0",
|
||||
"pydantic==2.7.0",
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"ipywidgets"
|
||||
]
|
||||
|
||||
@@ -55,6 +52,11 @@ license-files = ["LICENSE"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
"flake8==6.0.0",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
@@ -97,7 +97,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
@@ -128,164 +127,3 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
|
||||
self, actions: Dict[str, ActType]
|
||||
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
|
||||
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
|
||||
|
||||
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
|
||||
:type actions: Dict[str, ActType]
|
||||
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
|
||||
# 3. Get next observations
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
174
src/primaite/session/ray_envs.py
Normal file
174
src/primaite/session/ray_envs.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import json
|
||||
from typing import Dict, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
|
||||
self, actions: Dict[str, ActType]
|
||||
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
|
||||
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
|
||||
|
||||
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
|
||||
:type actions: Dict[str, ActType]
|
||||
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
|
||||
# 3. Get next observations
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
@@ -3,7 +3,7 @@ import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
|
||||
@@ -4,7 +4,8 @@ import yaml
|
||||
from gymnasium.core import ObsType
|
||||
from numpy import ndarray
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
|
||||
|
||||
Reference in New Issue
Block a user