#1386: added the ability to set deterministic and seeding RNG when training and evaluating + the fix provided in #1535

This commit is contained in:
Czar Echavez
2023-06-20 10:41:30 +01:00
parent dc0349c37b
commit 9fb30ffe1b
13 changed files with 1217 additions and 399 deletions

View File

@@ -137,6 +137,12 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
deterministic: bool = False
"If true, the training will be deterministic"
seed: int = None
"The random number generator seed to be used while training the agent"
def to_dict(self, json_serializable: bool = True):
"""
Serialise the ``TrainingConfig`` as dict.

View File

@@ -45,6 +45,7 @@ from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
@@ -221,7 +222,7 @@ class Primaite(Env):
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
@@ -232,11 +233,11 @@ class Primaite(Env):
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
else:
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
# Set up a csv to store the results of the training
@@ -405,8 +406,13 @@ class Primaite(Env):
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
self.__close__()
def __close__(self):
"""Override close function."""
"""
Override close function
"""
self.csv_file.close()
def init_acl(self):

View File

@@ -14,8 +14,8 @@ from pathlib import Path
from typing import Final, Union
from uuid import uuid4
import numpy as np
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.ppo import MlpPolicy as PPOMlp
@@ -54,9 +54,6 @@ def run_generic(env: Primaite, config_values: TrainingConfig):
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
env.close()
@@ -90,7 +87,7 @@ def run_stable_baselines3_ppo(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -103,8 +100,19 @@ def run_stable_baselines3_ppo(
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
for episode in range(0, config_values.num_episodes):
obs = env.reset()
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = env.step(action)
env.close()
@@ -138,7 +146,7 @@ def run_stable_baselines3_a2c(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -151,7 +159,18 @@ def run_stable_baselines3_a2c(
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
for episode in range(0, config_values.num_episodes):
obs = env.reset()
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = env.step(action)
env.close()