diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 883e844b..95a00f49 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -377,6 +377,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): time.sleep(self._training_config.time_delay / 1000) obs = self._env.reset() self._env.close() + super().evaluate() @classmethod def load(cls): diff --git a/tests/config/ppo_not_seeded_training_config.yaml b/tests/config/ppo_not_seeded_training_config.yaml index 23cff44e..14b3f087 100644 --- a/tests/config/ppo_not_seeded_training_config.yaml +++ b/tests/config/ppo_not_seeded_training_config.yaml @@ -60,10 +60,16 @@ observation_space: # - name: NODE_STATUSES # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session -num_episodes: 10 +num_train_episodes: 10 # Number of time_steps per episode -num_steps: 256 +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 10 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 # Sets how often the agent will save a checkpoint (every n time episodes). # Set to 0 if no checkpoints are required. Default is 10 diff --git a/tests/config/ppo_seeded_training_config.yaml b/tests/config/ppo_seeded_training_config.yaml index 181331d9..a176c793 100644 --- a/tests/config/ppo_seeded_training_config.yaml +++ b/tests/config/ppo_seeded_training_config.yaml @@ -60,10 +60,16 @@ observation_space: # - name: NODE_STATUSES # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session -num_episodes: 10 +num_train_episodes: 10 # Number of time_steps per episode -num_steps: 256 +num_train_steps: 256 + +# Number of episodes for evaluation to run per session +num_eval_episodes: 1 + +# Number of time_steps for evaluation per episode +num_eval_steps: 256 # Sets how often the agent will save a checkpoint (every n time episodes). # Set to 0 if no checkpoints are required. 
Default is 10 diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 859b2ab3..0f378634 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -23,16 +23,11 @@ agent_identifier: RANDOM # "ANY" node and acl actions action_type: ANY # Number of episodes for training to run per session -num_train_episodes: 10 +num_train_episodes: 1 # Number of time_steps for training per episode -num_train_steps: 256 +num_train_steps: 15 -# Number of episodes for evaluation to run per session -num_eval_episodes: 10 - -# Number of time_steps for evaluation per episode -num_eval_steps: 256 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) diff --git a/tests/config/single_action_space_lay_down_config.yaml b/tests/config/single_action_space_lay_down_config.yaml index c80c0bab..9d05b84a 100644 --- a/tests/config/single_action_space_lay_down_config.yaml +++ b/tests/config/single_action_space_lay_down_config.yaml @@ -32,14 +32,6 @@ - name: ftp port: '21' state: COMPROMISED -- item_type: POSITION - positions: - - node: '1' - x_pos: 309 - y_pos: 78 - - node: '2' - x_pos: 200 - y_pos: 78 - item_type: RED_IER id: '3' start_step: 2 diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml index e0fc40ee..e2b24b41 100644 --- a/tests/config/test_random_red_main_config.yaml +++ b/tests/config/test_random_red_main_config.yaml @@ -29,16 +29,16 @@ random_red_agent: True # "ANY" node and acl actions action_type: NODE # Number of episodes for training to run per session -num_train_episodes: 10 +num_train_episodes: 2 # Number of time_steps for training per episode -num_train_steps: 256 +num_train_steps: 15 # Number of episodes for evaluation to run per session -num_eval_episodes: 10 
+num_eval_episodes: 2 # Number of time_steps for evaluation per episode -num_eval_steps: 256 +num_eval_steps: 15 # Time delay between steps (for generic agents) time_delay: 1 diff --git a/tests/test_reward.py b/tests/test_reward.py index d1b56671..bb6eb1b0 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -47,6 +47,6 @@ def test_rewards_are_being_penalised_at_each_step_function( Average Reward: -8 (-120 / 15) """ with temp_primaite_session as session: - session.close() + session.evaluate() ev_rewards = session.eval_av_reward_per_episode_csv() assert ev_rewards[1] == -8.0