Merged PR 115: Configure Different Episode and Step Counts for Training and Evaluation

## Summary
Training configs now have 2 different types of episode and step counts - one for train and one for evaluation.

`num_train_episodes`
`num_train_steps`
`num_eval_episodes`
`num_eval_steps`

## Test process
A test file `test_train_eval_episode_steps.py` has been implemented which runs train and evaluation session on two particular configs.

The train and evaluation sessions have different episodes and step count and the test checks that the output log files have the correct number of `total_steps` and `total_episodes`.

## Checklist
- [X] This PR is linked to a **work item**
- [X] I have performed **self-review** of the code
- [X] I have written **tests** for any new functionality added with this PR
- [X] I have updated the **documentation** if this PR changes or adds functionality
- [X] I have run **pre-commit** checks for code style

Related work items: #1566, #1589
This commit is contained in:
Sunil Samra
2023-07-12 13:34:58 +00:00
27 changed files with 389 additions and 169 deletions

View File

@@ -83,13 +83,24 @@ The environment config file consists of the following attributes:
The other configurable item is ``flatten`` which is false by default. When set to true, the observation space is flattened (turned into a 1-D vector). You should use this if your RL agent does not natively support observation space types like ``gym.Spaces.Tuple``.
* **num_episodes** [int]
* **num_train_episodes** [int]
This defines the number of episodes that the agent will train or be evaluated over.
This defines the number of episodes that the agent will train for.
* **num_steps** [int]
Determines the number of steps to run in each episode of the session
* **num_train_steps** [int]
Determines the number of steps to run in each episode of the training session.
* **num_eval_episodes** [int]
This defines the number of episodes that the agent will be evaluated over.
* **num_eval_steps** [int]
Determines the number of steps to run in each episode of the evaluation session.
* **time_delay** [int]

View File

@@ -162,12 +162,11 @@ class AgentSessionABC(ABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
@@ -218,10 +217,11 @@ class AgentSessionABC(ABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
self._plot_av_reward_per_episode(learning_session=False)
_LOGGER.info("Finished evaluation")
if self._can_evaluate:
self._plot_av_reward_per_episode(learning_session=False)
self._update_session_metadata_file()
self.is_eval = True
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
@@ -375,8 +375,8 @@ class HardCodedAgentSessionABC(AgentSessionABC):
self._env.set_as_eval() # noqa
self.is_eval = True
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
obs = self._env.reset()
for episode in range(episodes):
@@ -395,6 +395,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
time.sleep(self._training_config.time_delay / 1000)
obs = self._env.reset()
self._env.close()
super().evaluate()
@classmethod
def load(cls):

View File

@@ -97,8 +97,12 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
if not self.is_eval:
metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
else:
metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -122,13 +126,13 @@ class RLlibAgent(AgentSessionABC):
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
horizon=self._training_config.num_train_steps,
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
@@ -150,8 +154,8 @@ class RLlibAgent(AgentSessionABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
@@ -162,9 +166,6 @@ class RLlibAgent(AgentSessionABC):
super().learn()
# save agent
self.save()
def evaluate(
self,
**kwargs,

View File

@@ -65,11 +65,12 @@ class SB3Agent(AgentSessionABC):
session_path=self.session_path,
timestamp_str=self.timestamp_str,
)
self._agent = self._agent_class(
PPOMlp,
self._env,
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
n_steps=self._training_config.num_train_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
@@ -97,14 +98,14 @@ class SB3Agent(AgentSessionABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_train_steps
episodes = self._training_config.num_train_episodes
self.is_eval = False
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self._env._write_av_reward_per_episode() # noqa
self.save()
self._env.close()
super().learn()
@@ -121,8 +122,8 @@ class SB3Agent(AgentSessionABC):
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_eval_steps
episodes = self._training_config.num_eval_episodes
self._env.set_as_eval()
self.is_eval = True
if self._training_config.deterministic:
@@ -140,7 +141,7 @@ class SB3Agent(AgentSessionABC):
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env._write_av_reward_per_episode() # noqa
self._env.close()
super().evaluate()

View File

@@ -59,11 +59,19 @@ observation_space:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10

View File

@@ -60,11 +60,17 @@ class TrainingConfig:
action_type: ActionType = ActionType.ANY
"The ActionType to use"
num_episodes: int = 10
"The number of episodes to train over"
num_train_episodes: int = 10
"The number of episodes to train over during an training session"
num_steps: int = 256
"The number of steps in an episode"
num_train_steps: int = 256
"The number of steps in an episode during an training session"
num_eval_episodes: int = 1
"The number of episodes to train over during an evaluation session"
num_eval_steps: int = 256
"The number of steps in an episode during an evaluation session"
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
@@ -236,8 +242,17 @@ class TrainingConfig:
tc += f"{self.hard_coded_agent_view}, "
tc += f"{self.action_type}, "
tc += f"observation_space={self.observation_space}, "
tc += f"{self.num_episodes} episodes @ "
tc += f"{self.num_steps} steps"
if self.session_type is SessionType.TRAIN:
tc += f"{self.num_train_episodes} episodes @ "
tc += f"{self.num_train_steps} steps"
elif self.session_type is SessionType.EVAL:
tc += f"{self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
else:
tc += f"Training: {self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
tc += f"{self.num_eval_steps} steps"
return tc
@@ -285,24 +300,27 @@ def convert_legacy_training_config_dict(
agent_framework: AgentFramework = AgentFramework.SB3,
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256,
num_train_steps: int = 256,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values.
:param agent_identifier: The red agent identifier to use as legacy training configs don't have agent_identifier
values.
:param action_type: The action space type to set as legacy training configs don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs don't have num_steps values.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param agent_identifier: The red agent identifier to use as legacy
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_train_steps: The number of steps to set as legacy training configs
don't have num_train_steps values.
:return: The converted training config dict.
"""
config_dict = {
"agent_framework": agent_framework.name,
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps,
"num_train_steps": num_train_steps,
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
@@ -323,7 +341,8 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
"""
key_mapping = {
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"numEpisodes": "num_train_episodes",
"numSteps": "num_train_steps",
"timeDelay": "time_delay",
"configFilename": None,
"sessionType": "session_type",

View File

@@ -84,7 +84,12 @@ class Primaite(Env):
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps = self.training_config.num_steps
if self.training_config.session_type == SessionType.TRAIN:
self.episode_steps = self.training_config.num_train_steps
elif self.training_config.session_type == SessionType.EVAL:
self.episode_steps = self.training_config.num_eval_steps
else:
self.episode_steps = self.training_config.num_train_steps
super(Primaite, self).__init__()
@@ -253,6 +258,12 @@ class Primaite(Env):
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
self.episode_steps = self.training_config.num_eval_steps
def _write_av_reward_per_episode(self):
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
def reset(self):
"""
@@ -261,10 +272,7 @@ class Primaite(Env):
Returns:
Environment observation space (reset)
"""
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
self._write_av_reward_per_episode()
self.episode_count += 1
# Don't need to reset links, as they are cleared and recalculated every

View File

@@ -90,7 +90,6 @@ def calculate_reward_function(
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value

View File

@@ -15,5 +15,6 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
:param av_rewards_csv_file: The average rewards per episode csv file path.
:return: The average rewards per episode cdv as a dict.
"""
d = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])}
df = pl.read_csv(av_rewards_csv_file).to_dict()
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}

View File

@@ -20,10 +20,12 @@ agent_identifier: PPO
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)

View File

@@ -22,11 +22,11 @@ agent_identifier: A2C
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
observation_space:
components:

View File

@@ -22,10 +22,11 @@ agent_identifier: RANDOM
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
observation_space:
components:

View File

@@ -22,10 +22,12 @@ agent_identifier: RANDOM
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
observation_space:
components:

View File

@@ -22,10 +22,11 @@ agent_identifier: RANDOM
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)

View File

@@ -18,11 +18,6 @@
- name: ftp
port: '21'
state: GOOD
- item_type: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- item_type: RED_POL
id: '1'
start_step: 1

View File

@@ -22,10 +22,13 @@ agent_identifier: DUMMY
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 15
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1

View File

@@ -60,10 +60,16 @@ observation_space:
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
num_train_episodes: 10
# Number of time_steps per episode
num_steps: 256
num_train_steps: 256
# Number of episodes to run per session
num_eval_episodes: 10
# Number of time_steps per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10

View File

@@ -60,10 +60,16 @@ observation_space:
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
num_train_episodes: 10
# Number of time_steps per episode
num_steps: 256
num_train_steps: 256
# Number of episodes to run per session
num_eval_episodes: 1
# Number of time_steps per episode
num_eval_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10

View File

@@ -22,10 +22,12 @@ agent_identifier: RANDOM
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 15
# Number of episodes for training to run per session
num_train_episodes: 1
# Number of time_steps for training per episode
num_train_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)

View File

@@ -32,14 +32,6 @@
- name: ftp
port: '21'
state: COMPROMISED
- item_type: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- node: '2'
x_pos: 200
y_pos: 78
- item_type: RED_IER
id: '3'
start_step: 2

View File

@@ -22,10 +22,17 @@ agent_identifier: RANDOM
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Number of episodes for training to run per session
num_train_episodes: 10
# Number of time_steps for training per episode
num_train_steps: 256
# Number of episodes for evaluation to run per session
num_eval_episodes: 10
# Number of time_steps for evaluation per episode
num_eval_steps: 256
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)

View File

@@ -28,10 +28,17 @@ random_red_agent: True
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 2
# Number of time_steps per episode
num_steps: 15
# Number of episodes for training to run per session
num_train_episodes: 2
# Number of time_steps for training per episode
num_train_steps: 15
# Number of episodes for evaluation to run per session
num_eval_episodes: 2
# Number of time_steps for evaluation per episode
num_eval_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1

View File

@@ -0,0 +1,153 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes for training to run per session
num_train_episodes: 3
# Number of time_steps for training per episode
num_train_steps: 25
# Number of episodes for evaluation to run per session
num_eval_episodes: 1
# Number of time_steps for evaluation per episode
num_eval_steps: 17
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,17 +1,16 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import datetime
import json
import shutil
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Union
from unittest.mock import patch
import pytest
from primaite import getLogger
from primaite.common.enums import AgentIdentifier
from primaite.environment.primaite_env import Primaite
from primaite.primaite_session import PrimaiteSession
from primaite.utils.session_output_reader import av_rewards_dict
@@ -48,6 +47,11 @@ class TempPrimaiteSession(PrimaiteSession):
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
return av_rewards_dict(self.evaluation_path / csv_file)
def metadata_file_as_dict(self) -> Dict[str, Any]:
"""Read the session_metadata.json file and return as a dict."""
with open(self.session_path / "session_metadata.json", "r") as file:
return json.load(file)
@property
def env(self) -> Primaite:
"""Direct access to the env for ease of testing."""
@@ -58,6 +62,7 @@ class TempPrimaiteSession(PrimaiteSession):
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@@ -129,58 +134,3 @@ def temp_session_path() -> Path:
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def _get_primaite_env_from_config(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
temp_session_path,
):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path,
session_path=session_path,
timestamp_str=timestamp_str,
)
config_values = env.training_config
config_values.num_steps = env.episode_steps
# TOOD: This needs t be refactored to happen outside. Should be part of
# a main Session class.
if env.training_config.agent_identifier is AgentIdentifier.RANDOM:
run_generic(env, config_values)
return env
def run_generic(env, config_values):
"""Run against a generic agent."""
# Reset the environment at the start of the episode
# env.reset()
for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
# env.reset()
# env.close()

View File

@@ -1,7 +1,10 @@
import pytest
from primaite import getLogger
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
@@ -45,6 +48,5 @@ def test_rewards_are_being_penalised_at_each_step_function(
"""
with temp_primaite_session as session:
session.evaluate()
session.close()
ev_rewards = session.eval_av_reward_per_episode_csv()
assert ev_rewards[1] == -8.0

View File

@@ -12,8 +12,8 @@ def run_generic_set_actions(env: Primaite):
# Reset the environment at the start of the episode
# env.reset()
training_config = env.training_config
for episode in range(0, training_config.num_episodes):
for step in range(0, training_config.num_steps):
for episode in range(0, training_config.num_train_episodes):
for step in range(0, training_config.num_train_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)

View File

@@ -0,0 +1,42 @@
import pytest
from primaite import getLogger
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
_LOGGER = getLogger(__name__)
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "train_episode_step.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_eval_steps_differ_from_training(temp_primaite_session):
"""Uses PrimaiteSession class to compare number of episodes used for training and evaluation.
Train_episode_step.yaml main config:
num_train_steps = 25
num_train_episodes = 3
num_eval_steps = 17
num_eval_episodes = 1
"""
expected_learning_metadata = {"total_episodes": 3, "total_time_steps": 75}
expected_evaluation_metadata = {"total_episodes": 1, "total_time_steps": 17}
with temp_primaite_session as session:
# Run learning and check episode and step counts
session.learn()
assert session.env.actual_episode_count == expected_learning_metadata["total_episodes"]
assert session.env.total_step_count == expected_learning_metadata["total_time_steps"]
# Run evaluation and check episode and step counts
session.evaluate()
assert session.env.actual_episode_count == expected_evaluation_metadata["total_episodes"]
assert session.env.total_step_count == expected_evaluation_metadata["total_time_steps"]
# Load the session_metadata.json file and check that the both the
# learning and evaluation match what is expected above
metadata = session.metadata_file_as_dict()
assert metadata["learning"] == expected_learning_metadata
assert metadata["evaluation"] == expected_evaluation_metadata