#1386: added the ability to set a deterministic mode and RNG seed when training and evaluating, plus the fix provided in #1535
The new options land in `TrainingConfig`:

```diff
@@ -137,6 +137,12 @@ class TrainingConfig:
     file_system_scanning_limit: int = 5
     "The time taken to scan the file system."
 
+    deterministic: bool = False
+    "If true, the training will be deterministic"
+
+    seed: int = None
+    "The random number generator seed to be used while training the agent"
+
     def to_dict(self, json_serializable: bool = True):
         """
         Serialise the ``TrainingConfig`` as dict.
```
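A minimal sketch of how the two new fields might be used together (hypothetical usage; the import path and the other constructor defaults are assumptions, not part of this diff):

```python
# Hypothetical: fix the seed and enable deterministic action selection
# for a reproducible run. The module path is an assumption.
from primaite.config.training_config import TrainingConfig

config = TrainingConfig(
    deterministic=True,  # evaluation will use argmax actions (see below)
    seed=42,             # seeds the Gym action space and the SB3 agent
)
print(config.to_dict())
```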
In the `Primaite` environment module's imports:

```diff
@@ -45,6 +45,7 @@ from primaite.pol.ier import IER
 from primaite.pol.red_agent_pol import apply_red_agent_iers, \
     apply_red_agent_node_pol
 from primaite.transactions.transaction import Transaction
+from primaite.transactions.transactions_to_file import write_transaction_to_file
 
 _LOGGER = logging.getLogger(__name__)
 _LOGGER.setLevel(logging.INFO)
```
In the `Primaite` environment, each `spaces.Discrete` action space is now constructed with the configured seed:

```diff
@@ -221,7 +222,7 @@ class Primaite(Env):
             # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
             # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
             self.action_dict = self.create_node_action_dict()
-            self.action_space = spaces.Discrete(len(self.action_dict))
+            self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
         elif self.training_config.action_type == ActionType.ACL:
             _LOGGER.info("Action space type ACL selected")
             # Terms (for ACL action space):
@@ -232,11 +233,11 @@ class Primaite(Env):
             # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
             # [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
             self.action_dict = self.create_acl_action_dict()
-            self.action_space = spaces.Discrete(len(self.action_dict))
+            self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
         elif self.training_config.action_type == ActionType.ANY:
             _LOGGER.info("Action space type ANY selected - Node + ACL")
             self.action_dict = self.create_node_and_acl_action_dict()
-            self.action_space = spaces.Discrete(len(self.action_dict))
+            self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
         else:
             _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
         # Set up a csv to store the results of the training
```
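Seeding the space matters because `Discrete.sample()` draws from the space's own RNG. A minimal standalone sketch of the effect, using the same `gym` API the diff relies on:

```python
from gym import spaces

# Two spaces built with the same seed yield identical random action
# sequences; with seed=None, gym falls back to nondeterministic seeding.
a = spaces.Discrete(10, seed=42)
b = spaces.Discrete(10, seed=42)
assert [a.sample() for _ in range(5)] == [b.sample() for _ in range(5)]
```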
Also in the `Primaite` environment, a public `close()` is added and the `__close__` docstring is reflowed:

```diff
@@ -405,8 +406,13 @@ class Primaite(Env):
         # Return
         return self.env_obs, reward, done, self.step_info
 
+    def close(self):
+        self.__close__()
+
     def __close__(self):
-        """Override close function."""
+        """
+        Override close function
+        """
         self.csv_file.close()
 
     def init_acl(self):
```
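The point of the wrapper is that gym-style callers invoke `env.close()`, which now forwards to the existing `__close__` teardown. The pattern in isolation (a hypothetical minimal class, not PrimAITE code):

```python
import tempfile

# Hypothetical illustration of the delegation pattern above: a public
# close() that gym-style callers expect, forwarding to the existing
# __close__ teardown that releases the CSV file handle.
class FileBackedEnv:
    def __init__(self, path: str):
        self.csv_file = open(path, "w")

    def close(self):
        self.__close__()

    def __close__(self):
        """Override close function."""
        self.csv_file.close()

env = FileBackedEnv(tempfile.mkstemp(suffix=".csv")[1])
env.close()
assert env.csv_file.closed
```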
In the session-runner module, `numpy` is imported for the evaluation loops below, and the now-unused `evaluate_policy` import is dropped:

```diff
@@ -14,8 +14,8 @@ from pathlib import Path
 from typing import Final, Union
 from uuid import uuid4
 
+import numpy as np
 from stable_baselines3 import A2C, PPO
-from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.ppo import MlpPolicy as PPOMlp
 
```
In `run_generic`, the end-of-episode reset block is removed:

```diff
@@ -54,9 +54,6 @@ def run_generic(env: Primaite, config_values: TrainingConfig):
 
             # Introduce a delay between steps
             time.sleep(config_values.time_delay / 1000)
 
-        # Reset the environment at the end of the episode
-        env.reset()
-
     env.close()
 
```
In `run_stable_baselines3_ppo`, newly created agents are seeded from the training config:

```diff
@@ -90,7 +87,7 @@ def run_stable_baselines3_ppo(
             _LOGGER.error("Could not load agent")
             _LOGGER.error("Exception occured", exc_info=True)
         else:
-            agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps)
+            agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
 
         if config_values.session_type == "TRAINING":
            # We're in a training session
```
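Stable-Baselines3's `seed` argument seeds the Python, NumPy, and PyTorch RNGs used by the algorithm, so identically seeded agents start from identical policy weights. A standalone sketch (CartPole as a stand-in environment, an assumption for illustration):

```python
from stable_baselines3 import PPO

# Two identically seeded agents initialise identical networks.
a = PPO("MlpPolicy", "CartPole-v1", seed=42)
b = PPO("MlpPolicy", "CartPole-v1", seed=42)

wa, wb = a.policy.state_dict(), b.policy.state_dict()
assert all((wa[k] == wb[k]).all() for k in wa)
```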
The PPO evaluation path replaces `evaluate_policy` with an explicit episode/step loop that passes the configured `deterministic` flag to `agent.predict`:

```diff
@@ -103,8 +100,19 @@ def run_stable_baselines3_ppo(
         # Default to being in an evaluation session
         print("Starting evaluation session...")
         _LOGGER.debug("Starting evaluation session...")
-        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
+        for episode in range(0, config_values.num_episodes):
+            obs = env.reset()
+
+            for step in range(0, config_values.num_steps):
+                action, _states = agent.predict(
+                    obs,
+                    deterministic=env.training_config.deterministic
+                )
+                # convert to int if action is a numpy array
+                if isinstance(action, np.ndarray):
+                    action = np.int64(action)
+                obs, rewards, done, info = env.step(action)
 
     env.close()
```
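With `deterministic=True`, `predict()` returns the mode of the policy's action distribution rather than a sample, so repeated queries on the same observation agree. A standalone sketch (CartPole as a hypothetical stand-in):

```python
from stable_baselines3 import PPO

agent = PPO("MlpPolicy", "CartPole-v1", seed=0)
obs = agent.get_env().reset()

# Deterministic predictions on a fixed observation never vary.
actions = {agent.predict(obs, deterministic=True)[0].item() for _ in range(10)}
assert len(actions) == 1
```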
`run_stable_baselines3_a2c` gets the same treatment:

```diff
@@ -138,7 +146,7 @@ def run_stable_baselines3_a2c(
             _LOGGER.error("Could not load agent")
             _LOGGER.error("Exception occured", exc_info=True)
         else:
-            agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
+            agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
 
         if config_values.session_type == "TRAINING":
             # We're in a training session
```
and its evaluation loop mirrors the PPO one:

```diff
@@ -151,7 +159,18 @@ def run_stable_baselines3_a2c(
         # Default to being in an evaluation session
         print("Starting evaluation session...")
         _LOGGER.debug("Starting evaluation session...")
-        evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
+        for episode in range(0, config_values.num_episodes):
+            obs = env.reset()
+
+            for step in range(0, config_values.num_steps):
+                action, _states = agent.predict(
+                    obs,
+                    deterministic=env.training_config.deterministic
+                )
+                # convert to int if action is a numpy array
+                if isinstance(action, np.ndarray):
+                    action = np.int64(action)
+                obs, rewards, done, info = env.step(action)
 
     env.close()
```
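Taken together, the seeded action space, seeded agent, and deterministic `predict` make a full evaluation run repeatable. A hypothetical end-to-end check (CartPole standing in for the PrimAITE env; assumes the older gym API used in the diff, where `reset()` returns only the observation and `step()` returns a 4-tuple):

```python
import gym
from stable_baselines3 import A2C

def rollout(seed: int) -> float:
    """Run a short deterministic evaluation and return the total reward."""
    env = gym.make("CartPole-v1")
    env.seed(seed)               # seeds the env's own RNG (old gym API)
    env.action_space.seed(seed)  # mirrors spaces.Discrete(..., seed=...)
    agent = A2C("MlpPolicy", env, seed=seed)
    total = 0.0
    obs = env.reset()
    for _ in range(100):
        action, _states = agent.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(int(action))
        total += reward
        if done:
            obs = env.reset()
    return total

# Identical seeds produce identical episodes end to end.
assert rollout(7) == rollout(7)
```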