Merge remote-tracking branch 'origin/dev' into feature/1566-configure_episode-steps-learn-eval

# Conflicts:
#	src/primaite/config/training_config.py
This commit is contained in:
SunilSamra
2023-07-11 11:39:21 +01:00
51 changed files with 577 additions and 894 deletions

View File

@@ -0,0 +1 @@
"""Utilities for PrimAITE."""

View File

@@ -10,8 +10,7 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
"""
Read an average rewards per episode csv file and return as a dict.
The dictionary keys are the episode number, and the values are the mean
reward that episode.
The dictionary keys are the episode number, and the values are the mean reward that episode.
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.

View File

@@ -29,6 +29,18 @@ class SessionOutputWriter:
transaction_writer: bool = False,
learning_session: bool = True,
):
"""
Initialise the Session Output Writer.
:param env: PrimAITE gym environment.
:type env: Primaite
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
If `false` it will output the average reward per episode, defaults to False
:type transaction_writer: bool, optional
:param learning_session: Set to `true` to indicate that the current session is a training session. This
determines the name of the folder which contains the final output csv. Defaults to True
:type learning_session: bool, optional
"""
self._env = env
self.transaction_writer = transaction_writer
self.learning_session = learning_session
@@ -68,8 +80,7 @@ class SessionOutputWriter:
"""
Write a row of session data.
:param data: The row of data to write. Can be a Tuple or an instance
of Transaction.
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
"""
if isinstance(data, Transaction):
header, data = data.as_csv_data()