diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 5bbe881b..785d9757 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -300,7 +300,7 @@ def convert_legacy_training_config_dict( agent_framework: AgentFramework = AgentFramework.SB3, agent_identifier: AgentIdentifier = AgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, - num_steps: int = 256, + num_train_steps: int = 256, ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -312,15 +312,15 @@ def convert_legacy_training_config_dict( training configs don't have agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. - :param num_steps: The number of steps to set as legacy training configs - don't have num_steps values. + :param num_train_steps: The number of steps to set as legacy training configs + don't have num_train_steps values. :return: The converted training config dict. """ config_dict = { "agent_framework": agent_framework.name, "agent_identifier": agent_identifier.name, "action_type": action_type.name, - "num_steps": num_steps, + "num_train_steps": num_train_steps, "sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name, } session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} diff --git a/tests/config/legacy_conversion/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml index 5ca80742..c57741f7 100644 --- a/tests/config/legacy_conversion/new_training_config.yaml +++ b/tests/config/legacy_conversion/new_training_config.yaml @@ -26,11 +26,6 @@ num_train_episodes: 10 # Number of time_steps for training per episode num_train_steps: 256 -# Number of episodes for evaluation to run per session -num_eval_episodes: 10 - -# Number of time_steps for evaluation per episode -num_eval_steps: 256 # Time delay between steps (for generic agents) time_delay: 10 # Type of session to be run (TRAINING or EVALUATION)