Merge remote-tracking branch 'origin/dev' into feature/1566-configure_episode-steps-learn-eval
# Conflicts: # src/primaite/config/training_config.py
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Utilities for PrimAITE."""
|
||||
|
||||
@@ -10,8 +10,7 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
"""
|
||||
Read an average rewards per episode csv file and return as a dict.
|
||||
|
||||
The dictionary keys are the episode number, and the values are the mean
|
||||
reward that episode.
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode cdv as a dict.
|
||||
|
||||
@@ -29,6 +29,18 @@ class SessionOutputWriter:
|
||||
transaction_writer: bool = False,
|
||||
learning_session: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialise the Session Output Writer.
|
||||
|
||||
:param env: PrimAITE gym environment.
|
||||
:type env: Primaite
|
||||
:param transaction_writer: If `true`, this will output a full account of every transaction taken by the agent.
|
||||
If `false` it will output the average reward per episode, defaults to False
|
||||
:type transaction_writer: bool, optional
|
||||
:param learning_session: Set to `true` to indicate that the current session is a training session. This
|
||||
determines the name of the folder which contains the final output csv. Defaults to True
|
||||
:type learning_session: bool, optional
|
||||
"""
|
||||
self._env = env
|
||||
self.transaction_writer = transaction_writer
|
||||
self.learning_session = learning_session
|
||||
@@ -68,8 +80,7 @@ class SessionOutputWriter:
|
||||
"""
|
||||
Write a row of session data.
|
||||
|
||||
:param data: The row of data to write. Can be a Tuple or an instance
|
||||
of Transaction.
|
||||
:param data: The row of data to write. Can be a Tuple or an instance of Transaction.
|
||||
"""
|
||||
if isinstance(data, Transaction):
|
||||
header, data = data.as_csv_data()
|
||||
|
||||
Reference in New Issue
Block a user