Merge remote-tracking branch 'origin/dev' into user-guide-feedback-core-fixes
This commit is contained in:
@@ -26,7 +26,7 @@ jobs:
|
||||
displayName: 'Install build dependencies'
|
||||
|
||||
- script: |
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
displayName: 'Install PrimAITE for docs autosummary'
|
||||
|
||||
- script: |
|
||||
|
||||
@@ -14,31 +14,31 @@ parameters:
|
||||
- name: matrix
|
||||
type: object
|
||||
default:
|
||||
- job_name: 'UbuntuPython38'
|
||||
py: '3.8'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'UbuntuPython38'
|
||||
# py: '3.8'
|
||||
# img: 'ubuntu-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'UbuntuPython311'
|
||||
py: '3.11'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: true
|
||||
publish_coverage: true
|
||||
- job_name: 'WindowsPython38'
|
||||
py: '3.8'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'WindowsPython38'
|
||||
# py: '3.8'
|
||||
# img: 'windows-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'WindowsPython311'
|
||||
py: '3.11'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython38'
|
||||
py: '3.8'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
# - job_name: 'MacOSPython38'
|
||||
# py: '3.8'
|
||||
# img: 'macOS-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'MacOSPython311'
|
||||
py: '3.11'
|
||||
img: 'macOS-latest'
|
||||
@@ -82,12 +82,12 @@ stages:
|
||||
|
||||
- script: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]"
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
|
||||
2
.github/workflows/build-sphinx.yml
vendored
2
.github/workflows/build-sphinx.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
- name: Install PrimAITE for docs autosummary
|
||||
run: |
|
||||
set -x
|
||||
python -m pip install -e .[dev]
|
||||
python -m pip install -e .[dev,rl]
|
||||
|
||||
- name: Run build script for Sphinx pages
|
||||
env:
|
||||
|
||||
2
.github/workflows/python-package.yml
vendored
2
.github/workflows/python-package.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
- name: Install PrimAITE
|
||||
run: |
|
||||
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
|
||||
python -m pip install $PRIMAITE_WHEEL[dev]
|
||||
python -m pip install $PRIMAITE_WHEEL[dev,rl]
|
||||
|
||||
- name: Perform PrimAITE Setup
|
||||
run: |
|
||||
|
||||
@@ -43,7 +43,7 @@ cd ~\primaite
|
||||
python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -66,7 +66,7 @@ mkdir ~/primaite
|
||||
cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install primaite-3.0.0-py3-none-any.whl[rl]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -105,7 +105,7 @@ source venv/bin/activate
|
||||
#### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
python3 -m pip install -e .[dev,rl]
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
@@ -114,6 +114,9 @@ python3 -m pip install -e .[dev]
|
||||
primaite setup
|
||||
```
|
||||
|
||||
#### Note
|
||||
*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` flag in the pip install command.*
|
||||
|
||||
### Running PrimAITE
|
||||
|
||||
Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`.
|
||||
|
||||
BIN
docs/_static/firewall_acl.png
vendored
BIN
docs/_static/firewall_acl.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 35 KiB After Width: | Height: | Size: 23 KiB |
@@ -19,4 +19,3 @@
|
||||
:recursive:
|
||||
|
||||
primaite
|
||||
tests
|
||||
|
||||
@@ -82,7 +82,7 @@ Install PrimAITE
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install path/to/your/primaite.whl
|
||||
pip install path/to/your/primaite.whl[rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
@@ -135,12 +135,12 @@ For example:
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
pip install -e .[dev]
|
||||
pip install -e .[dev,rl]
|
||||
|
||||
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
|
||||
|
||||
|
||||
@@ -66,4 +66,7 @@ Glossary
|
||||
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
|
||||
|
||||
Gymnasium
|
||||
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
|
||||
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
|
||||
|
||||
User app home
|
||||
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>/` on linux/darwin and `C:\\Users\\<username>\\primaite<version>` on Windows.
|
||||
|
||||
@@ -36,12 +36,10 @@ dependencies = [
|
||||
"polars==0.18.4",
|
||||
"prettytable==3.8.0",
|
||||
"PyYAML==6.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0",
|
||||
"pydantic==2.7.0",
|
||||
"ray[rllib] >= 2.9, < 3",
|
||||
"ipywidgets"
|
||||
"ipywidgets",
|
||||
"deepdiff"
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
@@ -55,6 +53,11 @@ license-files = ["LICENSE"]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
rl = [
|
||||
"ray[rllib] >= 2.20.0, < 3",
|
||||
"tensorflow==2.12.0",
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
]
|
||||
dev = [
|
||||
"build==0.10.0",
|
||||
"flake8==6.0.0",
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentActionHistoryItem(BaseModel):
|
||||
class AgentHistoryItem(BaseModel):
|
||||
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
|
||||
|
||||
timestep: int
|
||||
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
|
||||
response: RequestResponse
|
||||
"""The response sent back by the simulator for this action."""
|
||||
|
||||
reward: Optional[float] = None
|
||||
|
||||
|
||||
class AgentStartSettings(BaseModel):
|
||||
"""Configuration values for when an agent starts performing actions."""
|
||||
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
self.agent_settings = agent_settings or AgentSettings()
|
||||
self.action_history: List[AgentActionHistoryItem] = []
|
||||
self.history: List[AgentHistoryItem] = []
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
|
||||
return self.reward_function.update(state=state, last_action_response=self.history[-1])
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.action_history.append(
|
||||
AgentActionHistoryItem(
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
)
|
||||
)
|
||||
|
||||
def save_reward_to_history(self) -> None:
|
||||
"""Update the most recent history item with the reward value."""
|
||||
self.history[-1].reward = self.reward_function.current_reward
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
@@ -34,7 +34,7 @@ from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Optional[Iterable[Union[str, int]]]
|
||||
@@ -44,7 +44,7 @@ class AbstractReward:
|
||||
"""Base class for reward function components."""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -64,7 +64,7 @@ class AbstractReward:
|
||||
class DummyReward(AbstractReward):
|
||||
"""Dummy reward function component which always returns 0."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state."""
|
||||
return 0.0
|
||||
|
||||
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
"""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
|
||||
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
self._last_request_failed: bool = False
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
|
||||
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
|
||||
self.callback: Callable[[str], float] = default_callback
|
||||
"""Method that retrieves an agent's current reward given the agent's name."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Simply access the other agent's reward and return it."""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@@ -389,7 +389,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
|
||||
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
|
||||
@@ -160,6 +160,7 @@ class PrimaiteGame:
|
||||
agent = self.agents[agent_name]
|
||||
if self.step_counter > 0: # can't get reward before first action
|
||||
agent.update_reward(state=state)
|
||||
agent.save_reward_to_history()
|
||||
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
|
||||
agent.reward_function.total_reward += agent.reward_function.current_reward
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
"# Imports\n",
|
||||
"\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint"
|
||||
@@ -63,7 +63,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
|
||||
@@ -392,7 +392,7 @@
|
||||
"# Imports\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
|
||||
"from primaite.game.agent.interface import AgentHistoryItem\n",
|
||||
"import yaml\n",
|
||||
"from pprint import pprint\n"
|
||||
]
|
||||
@@ -444,7 +444,7 @@
|
||||
"source": [
|
||||
"def friendly_output_red_action(info):\n",
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'DONOTHING':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
|
||||
@@ -25,13 +25,13 @@
|
||||
"from primaite.game.game import PrimaiteGame\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"\n",
|
||||
"import ray\n",
|
||||
"from ray import air, tune\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
|
||||
"\n",
|
||||
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
|
||||
"# to copy the files to your user data path.\n",
|
||||
@@ -60,8 +60,8 @@
|
||||
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
|
||||
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
|
||||
" )\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
" )\n"
|
||||
]
|
||||
|
||||
@@ -18,8 +18,7 @@
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteRayEnv\n",
|
||||
"from ray.rllib.algorithms import ppo\n",
|
||||
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
|
||||
"from ray import air, tune\n",
|
||||
"import ray\n",
|
||||
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
||||
@@ -52,8 +51,8 @@
|
||||
"\n",
|
||||
"config = (\n",
|
||||
" PPOConfig()\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
|
||||
" .rollouts(num_rollout_workers=0)\n",
|
||||
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
|
||||
" .env_runners(num_env_runners=0)\n",
|
||||
" .training(train_batch_size=128)\n",
|
||||
")\n"
|
||||
]
|
||||
@@ -74,7 +73,7 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 5 * 128}\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
@@ -97,7 +96,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -43,7 +43,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)"
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent['ref'] == 'defender':\n",
|
||||
" agent['agent_settings']['flatten_obs']=True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -177,7 +180,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -254,8 +254,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_A'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_A'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
@@ -285,8 +285,8 @@
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_B'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].action_history[i].action\n",
|
||||
" green_action = env.game.agents['green_B'].history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
@@ -12,6 +11,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -61,7 +61,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {
|
||||
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
|
||||
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
|
||||
} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, action, state, reward)
|
||||
@@ -90,9 +90,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
PacketCapture.clear()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
@@ -126,166 +127,5 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
|
||||
self, actions: Dict[str, ActType]
|
||||
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
|
||||
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
|
||||
|
||||
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
|
||||
:type actions: Dict[str, ActType]
|
||||
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
|
||||
# 3. Get next observations
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
@@ -87,7 +87,7 @@ class PrimaiteIO:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
|
||||
177
src/primaite/session/ray_envs.py
Normal file
177
src/primaite/session/ray_envs.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
from typing import Dict, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
self.observation_space = gymnasium.spaces.Dict(
|
||||
{
|
||||
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
|
||||
for name, agent in self.agents.items()
|
||||
}
|
||||
)
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
self._obs_space_in_preferred_format = True
|
||||
self._action_space_in_preferred_format = True
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, ProxyAgent]:
|
||||
"""Grab a fresh reference to the agents from this episode's game object."""
|
||||
return {name: self.game.rl_agents[name] for name in self._agent_ids}
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
self.episode_counter += 1
|
||||
PacketCapture.clear()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
def step(
|
||||
self, actions: Dict[str, ActType]
|
||||
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
|
||||
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
|
||||
|
||||
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
|
||||
:type actions: Dict[str, ActType]
|
||||
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
|
||||
identifier.
|
||||
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
|
||||
"""
|
||||
step = self.game.step_counter
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.pre_timestep()
|
||||
self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
|
||||
# 3. Get next observations
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
# 4. Get rewards
|
||||
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
|
||||
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
|
||||
terminateds = {name: False for name, _ in self.agents.items()}
|
||||
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
|
||||
infos = {name: {} for name, _ in self.agents.items()}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(step, actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
|
||||
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = output_dir / f"step_{step}.json"
|
||||
|
||||
data = {
|
||||
"episode": self.episode_counter,
|
||||
"step": step,
|
||||
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
|
||||
"reward": rewards,
|
||||
"state": state,
|
||||
}
|
||||
with open(path, "w") as file:
|
||||
json.dump(data, file)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
for agent_name in self._agent_ids:
|
||||
agent = self.game.rl_agents[agent_name]
|
||||
unflat_space = agent.observation_manager.space
|
||||
unflat_obs = agent.observation_manager.current_observation
|
||||
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
return obs
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
|
||||
|
||||
class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
|
||||
|
||||
def __init__(self, env_config: Dict) -> None:
|
||||
"""Initialise the environment.
|
||||
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
return self.env.reset(seed=seed)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
|
||||
"""Perform a step in the environment."""
|
||||
return self.env.step(action)
|
||||
|
||||
def close(self):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
@@ -21,6 +21,8 @@ class PacketCapture:
|
||||
The PCAPs are logged to: <simulation output directory>/<hostname>/<hostname>_<ip address>_pcap.log
|
||||
"""
|
||||
|
||||
_logger_instances: List[logging.Logger] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hostname: str,
|
||||
@@ -65,10 +67,12 @@ class PacketCapture:
|
||||
|
||||
if outbound:
|
||||
self.outbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
PacketCapture._logger_instances.append(self.outbound_logger)
|
||||
logger = self.outbound_logger
|
||||
else:
|
||||
self.inbound_logger = logging.getLogger(self._get_logger_name(outbound))
|
||||
logger = self.inbound_logger
|
||||
PacketCapture._logger_instances.append(self.inbound_logger)
|
||||
|
||||
logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs
|
||||
logger.addHandler(file_handler)
|
||||
@@ -122,3 +126,13 @@ class PacketCapture:
|
||||
if SIM_OUTPUT.save_pcap_logs:
|
||||
msg = frame.model_dump_json()
|
||||
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
|
||||
|
||||
@staticmethod
|
||||
def clear():
|
||||
"""Close all open PCAP file handlers."""
|
||||
for logger in PacketCapture._logger_instances:
|
||||
handlers = logger.handlers[:]
|
||||
for handler in handlers:
|
||||
logger.removeHandler(handler)
|
||||
handler.close()
|
||||
PacketCapture._logger_instances = []
|
||||
|
||||
@@ -3,7 +3,7 @@ import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
|
||||
@@ -4,7 +4,8 @@ import yaml
|
||||
from gymnasium.core import ObsType
|
||||
from numpy import ndarray
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayMARLEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.interface import AgentActionHistoryItem
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
@@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
@@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent):
|
||||
state = game.get_sim_state()
|
||||
reward_value = comp.calculate(
|
||||
state,
|
||||
last_action_response=AgentActionHistoryItem(
|
||||
last_action_response=AgentHistoryItem(
|
||||
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user