Merged PR 411: merge 3.0.0 into dev :)
@@ -1,4 +1,5 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -13,15 +14,14 @@ from primaite.config.load import data_manipulation_config_path
|
||||
|
||||
_LOGGER = primaite.getLogger(__name__)
|
||||
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results"
|
||||
_RESULTS_ROOT.mkdir(exist_ok=True, parents=True)
|
||||
_MAJOR_V = primaite.__version__.split(".")[0]
|
||||
|
||||
_OUTPUT_ROOT: Final[Path] = _BENCHMARK_ROOT / "output"
|
||||
# Clear and recreate the output directory
|
||||
if _OUTPUT_ROOT.exists():
|
||||
shutil.rmtree(_OUTPUT_ROOT)
|
||||
_OUTPUT_ROOT.mkdir()
|
||||
_BENCHMARK_ROOT = Path(__file__).parent
|
||||
_RESULTS_ROOT: Final[Path] = _BENCHMARK_ROOT / "results" / f"v{_MAJOR_V}"
|
||||
_VERSION_ROOT: Final[Path] = _RESULTS_ROOT / f"v{primaite.__version__}"
|
||||
_SESSION_METADATA_ROOT: Final[Path] = _VERSION_ROOT / "session_metadata"
|
||||
|
||||
_SESSION_METADATA_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class BenchmarkSession:
|
||||
@@ -51,9 +51,6 @@ class BenchmarkSession:
|
||||
end_time: datetime
|
||||
"""End time for the session."""
|
||||
|
||||
session_metadata: Dict
|
||||
"""Dict containing the metadata for the session - used to generate benchmark report."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gym_env: BenchmarkPrimaiteGymEnv,
|
||||
@@ -182,8 +179,14 @@ def run(
|
||||
learning_rate=learning_rate,
|
||||
)
|
||||
session.train()
|
||||
session_metadata_dict[i] = session.session_metadata
|
||||
|
||||
# Dump the session metadata so that we're not holding it in memory as it's large
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "w") as file:
|
||||
json.dump(session.session_metadata, file, indent=4)
|
||||
|
||||
for i in range(1, number_of_sessions + 1):
|
||||
with open(_SESSION_METADATA_ROOT / f"{i}.json", "r") as file:
|
||||
session_metadata_dict[i] = json.load(file)
|
||||
# generate report
|
||||
build_benchmark_latex_report(
|
||||
benchmark_start_time=benchmark_start_time,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -124,7 +123,7 @@ def _plot_benchmark_metadata(
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
|
||||
@@ -140,7 +139,8 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
benchmarked. The combined_av_reward_per_episode is extracted from each,
|
||||
converted into a polars dataframe, and plotted as a scatter line in plotly.
|
||||
"""
|
||||
title = "PrimAITE Versions Learning Benchmark"
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
title = f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*"
|
||||
subtitle = "Rolling Av (Combined Session Av)"
|
||||
if title:
|
||||
if subtitle:
|
||||
@@ -172,7 +172,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure:
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
yaxis={"title": "Total Reward"},
|
||||
title=title,
|
||||
)
|
||||
fig["data"][0]["showlegend"] = True
|
||||
@@ -188,8 +188,6 @@ def build_benchmark_latex_report(
|
||||
v_str = f"v{primaite.__version__}"
|
||||
|
||||
version_result_dir = results_root_path / v_str
|
||||
if version_result_dir.exists():
|
||||
shutil.rmtree(version_result_dir)
|
||||
version_result_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# load the config file as dict
|
||||
@@ -200,7 +198,7 @@ def build_benchmark_latex_report(
|
||||
benchmark_metadata_dict = _build_benchmark_results_dict(
|
||||
start_datetime=benchmark_start_time, metadata_dict=session_metadata, config=cfg_data
|
||||
)
|
||||
|
||||
major_v = primaite.__version__.split(".")[0]
|
||||
with open(version_result_dir / f"{v_str}_benchmark_metadata.json", "w") as file:
|
||||
json.dump(benchmark_metadata_dict, file, indent=4)
|
||||
title = f"PrimAITE v{primaite.__version__.strip()} Learning Benchmark"
|
||||
@@ -241,9 +239,9 @@ def build_benchmark_latex_report(
|
||||
f"with each episode consisting of {steps} steps."
|
||||
)
|
||||
doc.append(
|
||||
f"\nThe mean reward per episode from each session is captured. This is then used to calculate a "
|
||||
f"combined average reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the combined average reward per session is calculated for "
|
||||
f"\nThe total reward per episode from each session is captured. This is then used to calculate an "
|
||||
f"caverage total reward per episode from the {sessions} individual sessions for smoothing. "
|
||||
f"Finally, a 25-widow rolling average of the average total reward per session is calculated for "
|
||||
f"further smoothing."
|
||||
)
|
||||
|
||||
@@ -294,14 +292,14 @@ def build_benchmark_latex_report(
|
||||
table.add_hline()
|
||||
|
||||
with doc.create(Section("Graphs")):
|
||||
with doc.create(Subsection(f"PrimAITE {primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(Subsection(f"v{primaite_version} Learning Benchmark Plot")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(this_version_plot_path))
|
||||
pic.add_caption(f"PrimAITE {primaite_version} Learning Benchmark Plot")
|
||||
|
||||
with doc.create(Subsection("PrimAITE All Versions Learning Benchmark Plot")):
|
||||
with doc.create(Subsection(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")):
|
||||
with doc.create(LatexFigure(position="h!")) as pic:
|
||||
pic.add_image(str(all_version_plot_path))
|
||||
pic.add_caption("PrimAITE All Versions Learning Benchmark Plot")
|
||||
pic.add_caption(f"Learning Benchmarking of All Released Versions under Major v{major_v}.*.*")
|
||||
|
||||
doc.generate_pdf(str(this_version_plot_path).replace(".png", ""), clean_tex=True)
|
||||
|
||||
|
Before Width: | Height: | Size: 90 KiB |
BIN
benchmark/results/v2/PrimAITE Versions Learning Benchmark.png
Normal file
|
After Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 225 KiB After Width: | Height: | Size: 225 KiB |
|
Before Width: | Height: | Size: 296 KiB |
BIN
benchmark/results/v3/PrimAITE Versions Learning Benchmark.png
Normal file
|
After Width: | Height: | Size: 91 KiB |
|
After Width: | Height: | Size: 295 KiB |
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Welcome to PrimAITE's documentation
|
||||
====================================
|
||||
@@ -54,7 +54,7 @@ It is agnostic to the number of agents, their action / observation spaces, and t
|
||||
It presents a public API providing a method for describing the current state of the simulation, a method that accepts action requests and provides responses, and a method that triggers a timestep advancement.
|
||||
The Game Layer converts the simulation into a playable game for the agent(s).
|
||||
|
||||
it translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
|
||||
It translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
|
||||
|
||||
Agents can either generate their own scripted behaviour or accept input behaviour from an RL agent.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _about:
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
PrimAITE |VERSION| Configuration
|
||||
********************************
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. role:: raw-html(raw)
|
||||
:format: html
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _Developer Tools:
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Example Jupyter Notebooks
|
||||
=========================
|
||||
@@ -77,6 +77,6 @@ The following extensions should now be installed
|
||||
:width: 300
|
||||
:align: center
|
||||
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10
|
||||
VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.11
|
||||
|
||||
You should now be able to interact with the notebook.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _getting-started:
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Glossary
|
||||
=============
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Request System
|
||||
**************
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Simulation
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
|
||||
Simulation Structure
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Simulation State
|
||||
================
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
|
||||
Defining variations in the config files
|
||||
================
|
||||
@@ -24,7 +24,7 @@ For each variation that could be used in a placeholder, there is a separate yaml
|
||||
|
||||
The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand ``&anchor``.
|
||||
|
||||
Learn more about YAML Aliases and Anchors here.
|
||||
Learn more about YAML Aliases and Anchors `here <https://yaml.org/spec/1.2.2/#3222-anchors-and-aliases>`_.
|
||||
|
||||
Schedule
|
||||
********
|
||||
|
||||
@@ -4,13 +4,15 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Customising red agents\n",
|
||||
"# Customising UC2 Red Agents\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n",
|
||||
"\n",
|
||||
"First, let's load the standard Data Manipulation config file, and see what the red agent does.\n",
|
||||
"\n",
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*"
|
||||
"*(For a full explanation of the Data Manipulation scenario, check out the data manipulation scenario notebook)*"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Data Manipulation Scenario\n"
|
||||
"# Data Manipulation Scenario\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -79,7 +81,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reinforcement learning details"
|
||||
"## Reinforcement learning details"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -692,7 +694,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Getting information out of PrimAITE"
|
||||
"# Getting information out of PrimAITE\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -160,7 +162,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Requests and Responses\n",
|
||||
"\n",
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n"
|
||||
"Agents interact with the PrimAITE simulation via the Request system.\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Multi agent system using RLLIB\n",
|
||||
"# Train a Multi agent system using RLLIB\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
|
||||
]
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train a Single agent system using RLLib\n",
|
||||
"# Train a Single agent system using RLLib\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Training an SB3 Agent\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook will demonstrate how to use primaite to create and train a PPO agent, using a pre-defined configuration file."
|
||||
]
|
||||
},
|
||||
@@ -180,7 +182,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"source": [
|
||||
"# Using Episode Schedules\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
|
||||
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
|
||||
"\n",
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple multi-processing demo using SubprocVecEnv from SB3"
|
||||
"# Simple multi-processing demonstration\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This notebook uses SubprocVecEnv from SB3."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -139,7 +143,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -37,7 +37,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.average_reward_per_episode: Dict[int, float] = {}
|
||||
self.total_reward_per_episode: Dict[int, float] = {}
|
||||
"""Average rewards of agents per episode."""
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
self.average_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
"source": [
|
||||
"# Build a simulation using the Python API\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects.\n"
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"Currently, this notebook manipulates the simulation by directly placing objects inside of the attributes of the network and domain. It should be refactored when proper methods exist for adding these objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
"source": [
|
||||
"# PrimAITE Router Simulation Demo\n",
|
||||
"\n",
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK\n",
|
||||
"\n",
|
||||
"This demo uses a modified version of the ARCD Use Case 2 Network (seen below) to demonstrate the capabilities of the Network simulator in PrimAITE."
|
||||
]
|
||||
},
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
@@ -11,16 +11,16 @@ from typing import Any, Dict, Tuple, Union
|
||||
import polars as pl
|
||||
|
||||
|
||||
def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
def total_rewards_dict(total_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:param total_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
df_dict = pl.read_csv(total_rewards_csv_file).to_dict()
|
||||
|
||||
return {int(v): df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
@@ -26,9 +26,9 @@ class SessionOutputWriter:
|
||||
Is used to write session outputs to csv file.
|
||||
"""
|
||||
|
||||
_AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
_TOTAL_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [
|
||||
"Episode",
|
||||
"Average Reward",
|
||||
"Total Reward",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -43,7 +43,7 @@ class SessionOutputWriter:
|
||||
:param env: PrimAITE gym environment.
|
||||
:type env: Primaite
|
||||
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
|
||||
If `false` it will output the average reward per episode, defaults to False
|
||||
If `false` it will output the total reward per episode, defaults to False
|
||||
:type transaction_writer: bool, optional
|
||||
:param learning_session: Set to `true` to indicate that the current session is a training session. This
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
@@ -56,7 +56,7 @@ class SessionOutputWriter:
|
||||
if self.transaction_writer:
|
||||
fn = f"all_transactions_{self._env.timestamp_str}.csv"
|
||||
else:
|
||||
fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
fn = f"total_reward_per_episode_{self._env.timestamp_str}.csv"
|
||||
|
||||
self._csv_file_path: "Path"
|
||||
if self.learning_session:
|
||||
@@ -94,7 +94,7 @@ class SessionOutputWriter:
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
else:
|
||||
header = self._AV_REWARD_PER_EPISODE_HEADER
|
||||
header = self._TOTAL_REWARD_PER_EPISODE_HEADER
|
||||
|
||||
if self._first_write:
|
||||
self._init_csv_writer()
|
||||
|
||||