Merged PR 56: #902 - Fix the reward functionality for node operating system state
#902 - replaced 'final_node_<placeholder>' with 'reference_node_<placeholder>' in methods for scoring of os_state, file_system_state, service state and operating state. This fixed the reward function so it is checked at each step for node operating system state, operating state, file system state and service state. - Added unit tests. Related work items: #902
This commit is contained in:
@@ -37,6 +37,8 @@ from primaite.pol.ier import IER
|
|||||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||||
from primaite.transactions.transaction import Transaction
|
from primaite.transactions.transaction import Transaction
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Primaite(Env):
|
class Primaite(Env):
|
||||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||||
@@ -145,14 +147,12 @@ class Primaite(Env):
|
|||||||
|
|
||||||
# Open the config file and build the environment laydown
|
# Open the config file and build the environment laydown
|
||||||
try:
|
try:
|
||||||
self.config_file = open(
|
self.config_file = open(self.config_values.config_filename_use_case, "r")
|
||||||
"config/" + self.config_values.config_filename_use_case, "r"
|
|
||||||
)
|
|
||||||
self.config_data = yaml.safe_load(self.config_file)
|
self.config_data = yaml.safe_load(self.config_file)
|
||||||
self.load_config()
|
self.load_config()
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.error("Could not load the environment configuration")
|
_LOGGER.error("Could not load the environment configuration")
|
||||||
logging.error("Exception occured", exc_info=True)
|
_LOGGER.error("Exception occured", exc_info=True)
|
||||||
|
|
||||||
# Store the node objects as node attributes
|
# Store the node objects as node attributes
|
||||||
# (This is so we can access them as objects)
|
# (This is so we can access them as objects)
|
||||||
@@ -180,8 +180,8 @@ class Primaite(Env):
|
|||||||
plt.savefig(filename, format="PNG")
|
plt.savefig(filename, format="PNG")
|
||||||
plt.clf()
|
plt.clf()
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.error("Could not save network diagram")
|
_LOGGER.error("Could not save network diagram")
|
||||||
logging.error("Exception occured", exc_info=True)
|
_LOGGER.error("Exception occured", exc_info=True)
|
||||||
print("Could not save network diagram")
|
print("Could not save network diagram")
|
||||||
|
|
||||||
# Define Observation Space
|
# Define Observation Space
|
||||||
@@ -223,7 +223,7 @@ class Primaite(Env):
|
|||||||
|
|
||||||
# Define Action Space - depends on action space type (Node or ACL)
|
# Define Action Space - depends on action space type (Node or ACL)
|
||||||
if self.action_type == ACTION_TYPE.NODE:
|
if self.action_type == ACTION_TYPE.NODE:
|
||||||
logging.info("Action space type NODE selected")
|
_LOGGER.info("Action space type NODE selected")
|
||||||
# Terms (for node action space):
|
# Terms (for node action space):
|
||||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||||
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
|
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
|
||||||
@@ -238,7 +238,7 @@ class Primaite(Env):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Action space type ACL selected")
|
_LOGGER.info("Action space type ACL selected")
|
||||||
# Terms (for ACL action space):
|
# Terms (for ACL action space):
|
||||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||||
@@ -273,10 +273,10 @@ class Primaite(Env):
|
|||||||
self.csv_writer = csv.writer(self.csv_file)
|
self.csv_writer = csv.writer(self.csv_file)
|
||||||
self.csv_writer.writerow(header)
|
self.csv_writer.writerow(header)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.error(
|
_LOGGER.error(
|
||||||
"Could not create csv file to hold average reward per episode"
|
"Could not create csv file to hold average reward per episode"
|
||||||
)
|
)
|
||||||
logging.error("Exception occured", exc_info=True)
|
_LOGGER.error("Exception occured", exc_info=True)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
@@ -322,7 +322,7 @@ class Primaite(Env):
|
|||||||
step_info: Additional information relating to this step
|
step_info: Additional information relating to this step
|
||||||
"""
|
"""
|
||||||
if self.step_count == 0:
|
if self.step_count == 0:
|
||||||
print("Episode: " + str(self.episode_count) + " running")
|
print(f"Episode: {str(self.episode_count)}")
|
||||||
|
|
||||||
# TEMP
|
# TEMP
|
||||||
done = False
|
done = False
|
||||||
@@ -402,7 +402,7 @@ class Primaite(Env):
|
|||||||
self.step_count,
|
self.step_count,
|
||||||
self.config_values,
|
self.config_values,
|
||||||
)
|
)
|
||||||
# print("Step reward: " + str(reward))
|
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||||
self.total_reward += reward
|
self.total_reward += reward
|
||||||
if self.step_count == self.episode_steps:
|
if self.step_count == self.episode_steps:
|
||||||
self.average_reward = self.total_reward / self.step_count
|
self.average_reward = self.total_reward / self.step_count
|
||||||
@@ -410,7 +410,7 @@ class Primaite(Env):
|
|||||||
# For evaluation, need to trigger the done value = True when
|
# For evaluation, need to trigger the done value = True when
|
||||||
# step count is reached in order to prevent neverending episode
|
# step count is reached in order to prevent neverending episode
|
||||||
done = True
|
done = True
|
||||||
print("Average reward: " + str(self.average_reward))
|
print(f" Average Reward: {str(self.average_reward)}")
|
||||||
# Load the reward into the transaction
|
# Load the reward into the transaction
|
||||||
transaction.set_reward(reward)
|
transaction.set_reward(reward)
|
||||||
|
|
||||||
@@ -757,7 +757,7 @@ class Primaite(Env):
|
|||||||
# Do nothing (bad formatting)
|
# Do nothing (bad formatting)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
logging.info("Environment configuration loaded")
|
_LOGGER.info("Environment configuration loaded")
|
||||||
print("Environment configuration loaded")
|
print("Environment configuration loaded")
|
||||||
|
|
||||||
def create_node(self, item):
|
def create_node(self, item):
|
||||||
@@ -1090,7 +1090,7 @@ class Primaite(Env):
|
|||||||
item: A config data item representing steps info
|
item: A config data item representing steps info
|
||||||
"""
|
"""
|
||||||
self.episode_steps = int(steps_info["steps"])
|
self.episode_steps = int(steps_info["steps"])
|
||||||
logging.info("Training episodes have " + str(self.episode_steps) + " steps")
|
_LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||||
|
|
||||||
def reset_environment(self):
|
def reset_environment(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -98,27 +98,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
|||||||
score += config_values.all_ok
|
score += config_values.all_ok
|
||||||
else:
|
else:
|
||||||
# We're different from the reference situation
|
# We're different from the reference situation
|
||||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
# Need to compare initial and reference (current) state of node (i.e. at every step)
|
||||||
if initial_node_operating_state == HARDWARE_STATE.ON:
|
if initial_node_operating_state == HARDWARE_STATE.ON:
|
||||||
if final_node_operating_state == HARDWARE_STATE.OFF:
|
if reference_node_operating_state == HARDWARE_STATE.OFF:
|
||||||
score += config_values.off_should_be_on
|
score += config_values.off_should_be_on
|
||||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||||
score += config_values.resetting_should_be_on
|
score += config_values.resetting_should_be_on
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_operating_state == HARDWARE_STATE.OFF:
|
elif initial_node_operating_state == HARDWARE_STATE.OFF:
|
||||||
if final_node_operating_state == HARDWARE_STATE.ON:
|
if reference_node_operating_state == HARDWARE_STATE.ON:
|
||||||
score += config_values.on_should_be_off
|
score += config_values.on_should_be_off
|
||||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||||
score += config_values.resetting_should_be_off
|
score += config_values.resetting_should_be_off
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
|
elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||||
if final_node_operating_state == HARDWARE_STATE.ON:
|
if reference_node_operating_state == HARDWARE_STATE.ON:
|
||||||
score += config_values.on_should_be_resetting
|
score += config_values.on_should_be_resetting
|
||||||
elif final_node_operating_state == HARDWARE_STATE.OFF:
|
elif reference_node_operating_state == HARDWARE_STATE.OFF:
|
||||||
score += config_values.off_should_be_resetting
|
score += config_values.off_should_be_resetting
|
||||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||||
score += config_values.resetting
|
score += config_values.resetting
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@@ -148,29 +148,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
|||||||
score += config_values.all_ok
|
score += config_values.all_ok
|
||||||
else:
|
else:
|
||||||
# We're different from the reference situation
|
# We're different from the reference situation
|
||||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
# Need to compare initial and reference (current) state of node (i.e. at every step)
|
||||||
if initial_node_os_state == SOFTWARE_STATE.GOOD:
|
if initial_node_os_state == SOFTWARE_STATE.GOOD:
|
||||||
if final_node_os_state == SOFTWARE_STATE.PATCHING:
|
if reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching_should_be_good
|
score += config_values.patching_should_be_good
|
||||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised_should_be_good
|
score += config_values.compromised_should_be_good
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_os_state == SOFTWARE_STATE.PATCHING:
|
elif initial_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
if reference_node_os_state == SOFTWARE_STATE.GOOD:
|
||||||
score += config_values.good_should_be_patching
|
score += config_values.good_should_be_patching
|
||||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised_should_be_patching
|
score += config_values.compromised_should_be_patching
|
||||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching
|
score += config_values.patching
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
if reference_node_os_state == SOFTWARE_STATE.GOOD:
|
||||||
score += config_values.good_should_be_compromised
|
score += config_values.good_should_be_compromised
|
||||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching_should_be_compromised
|
score += config_values.patching_should_be_compromised
|
||||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised
|
score += config_values.compromised
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@@ -204,46 +204,46 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
|||||||
score += config_values.all_ok
|
score += config_values.all_ok
|
||||||
else:
|
else:
|
||||||
# We're different from the reference situation
|
# We're different from the reference situation
|
||||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
# Need to compare initial and reference state of node (i.e. at every step)
|
||||||
if initial_service.get_state() == SOFTWARE_STATE.GOOD:
|
if initial_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||||
if final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
if reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching_should_be_good
|
score += config_values.patching_should_be_good
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised_should_be_good
|
score += config_values.compromised_should_be_good
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||||
score += config_values.overwhelmed_should_be_good
|
score += config_values.overwhelmed_should_be_good
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_service.get_state() == SOFTWARE_STATE.PATCHING:
|
elif initial_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||||
score += config_values.good_should_be_patching
|
score += config_values.good_should_be_patching
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised_should_be_patching
|
score += config_values.compromised_should_be_patching
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||||
score += config_values.overwhelmed_should_be_patching
|
score += config_values.overwhelmed_should_be_patching
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching
|
score += config_values.patching
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||||
score += config_values.good_should_be_compromised
|
score += config_values.good_should_be_compromised
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching_should_be_compromised
|
score += config_values.patching_should_be_compromised
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised
|
score += config_values.compromised
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||||
score += config_values.overwhelmed_should_be_compromised
|
score += config_values.overwhelmed_should_be_compromised
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||||
score += config_values.good_should_be_overwhelmed
|
score += config_values.good_should_be_overwhelmed
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||||
score += config_values.patching_should_be_overwhelmed
|
score += config_values.patching_should_be_overwhelmed
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||||
score += config_values.compromised_should_be_overwhelmed
|
score += config_values.compromised_should_be_overwhelmed
|
||||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||||
score += config_values.overwhelmed
|
score += config_values.overwhelmed
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
@@ -276,67 +276,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
|||||||
score += config_values.all_ok
|
score += config_values.all_ok
|
||||||
else:
|
else:
|
||||||
# We're different from the reference situation
|
# We're different from the reference situation
|
||||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
# Need to compare initial and reference state of node (i.e. at every step)
|
||||||
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||||
if final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
if reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
score += config_values.repairing_should_be_good
|
score += config_values.repairing_should_be_good
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
score += config_values.restoring_should_be_good
|
score += config_values.restoring_should_be_good
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
score += config_values.corrupt_should_be_good
|
score += config_values.corrupt_should_be_good
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
score += config_values.destroyed_should_be_good
|
score += config_values.destroyed_should_be_good
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
elif initial_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||||
score += config_values.good_should_be_repairing
|
score += config_values.good_should_be_repairing
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
score += config_values.restoring_should_be_repairing
|
score += config_values.restoring_should_be_repairing
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
score += config_values.corrupt_should_be_repairing
|
score += config_values.corrupt_should_be_repairing
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
score += config_values.destroyed_should_be_repairing
|
score += config_values.destroyed_should_be_repairing
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
score += config_values.repairing
|
score += config_values.repairing
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||||
score += config_values.good_should_be_restoring
|
score += config_values.good_should_be_restoring
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
score += config_values.repairing_should_be_restoring
|
score += config_values.repairing_should_be_restoring
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
score += config_values.corrupt_should_be_restoring
|
score += config_values.corrupt_should_be_restoring
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
score += config_values.destroyed_should_be_restoring
|
score += config_values.destroyed_should_be_restoring
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
score += config_values.restoring
|
score += config_values.restoring
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||||
score += config_values.good_should_be_corrupt
|
score += config_values.good_should_be_corrupt
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
score += config_values.repairing_should_be_corrupt
|
score += config_values.repairing_should_be_corrupt
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
score += config_values.restoring_should_be_corrupt
|
score += config_values.restoring_should_be_corrupt
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
score += config_values.destroyed_should_be_corrupt
|
score += config_values.destroyed_should_be_corrupt
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
score += config_values.corrupt
|
score += config_values.corrupt
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||||
score += config_values.good_should_be_destroyed
|
score += config_values.good_should_be_destroyed
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||||
score += config_values.repairing_should_be_destroyed
|
score += config_values.repairing_should_be_destroyed
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||||
score += config_values.restoring_should_be_destroyed
|
score += config_values.restoring_should_be_destroyed
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||||
score += config_values.corrupt_should_be_destroyed
|
score += config_values.corrupt_should_be_destroyed
|
||||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||||
score += config_values.destroyed
|
score += config_values.destroyed
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -165,7 +165,9 @@ def load_config_values():
|
|||||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||||
config_values.time_delay = int(config_data["timeDelay"])
|
config_values.time_delay = int(config_data["timeDelay"])
|
||||||
config_values.config_filename_use_case = config_data["configFilename"]
|
config_values.config_filename_use_case = (
|
||||||
|
"config/" + config_data["configFilename"]
|
||||||
|
)
|
||||||
config_values.session_type = config_data["sessionType"]
|
config_values.session_type = config_data["sessionType"]
|
||||||
config_values.load_agent = bool(config_data["loadAgent"])
|
config_values.load_agent = bool(config_data["loadAgent"])
|
||||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
|
||||||
|
"The tests config root directory."
|
||||||
|
|||||||
0
tests/config/__init__.py
Normal file
0
tests/config/__init__.py
Normal file
125
tests/config/one_node_states_on_off_lay_down_config.yaml
Normal file
125
tests/config/one_node_states_on_off_lay_down_config.yaml
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
- itemType: ACTIONS
|
||||||
|
type: NODE
|
||||||
|
- itemType: STEPS
|
||||||
|
steps: 13
|
||||||
|
- itemType: PORTS
|
||||||
|
portsList:
|
||||||
|
- port: '21'
|
||||||
|
- itemType: SERVICES
|
||||||
|
serviceList:
|
||||||
|
- name: ftp
|
||||||
|
- itemType: NODE
|
||||||
|
id: '1'
|
||||||
|
name: node
|
||||||
|
baseType: SERVICE
|
||||||
|
nodeType: COMPUTER
|
||||||
|
priority: P1
|
||||||
|
hardwareState: 'ON'
|
||||||
|
ipAddress: 192.168.0.1
|
||||||
|
softwareState: GOOD
|
||||||
|
fileSystemState: GOOD
|
||||||
|
services:
|
||||||
|
- name: ftp
|
||||||
|
port: '21'
|
||||||
|
state: GOOD
|
||||||
|
- itemType: POSITION
|
||||||
|
positions:
|
||||||
|
- node: '1'
|
||||||
|
x_pos: 309
|
||||||
|
y_pos: 78
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '1'
|
||||||
|
startStep: 1
|
||||||
|
endStep: 3
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: FILE
|
||||||
|
protocol: NA
|
||||||
|
state: CORRUPT
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '2'
|
||||||
|
startStep: 3
|
||||||
|
endStep: 13
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: FILE
|
||||||
|
protocol: NA
|
||||||
|
state: GOOD
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '3'
|
||||||
|
startStep: 4
|
||||||
|
endStep: 6
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: OPERATING
|
||||||
|
protocol: NA
|
||||||
|
state: 'OFF'
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '4'
|
||||||
|
startStep: 6
|
||||||
|
endStep: 13
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: OPERATING
|
||||||
|
protocol: NA
|
||||||
|
state: 'ON'
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '5'
|
||||||
|
startStep: 7
|
||||||
|
endStep: 9
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: SERVICE
|
||||||
|
protocol: ftp
|
||||||
|
state: COMPROMISED
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '6'
|
||||||
|
startStep: 9
|
||||||
|
endStep: 13
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: SERVICE
|
||||||
|
protocol: ftp
|
||||||
|
state: GOOD
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '7'
|
||||||
|
startStep: 10
|
||||||
|
endStep: 12
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: OS
|
||||||
|
protocol: NA
|
||||||
|
state: COMPROMISED
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
|
- itemType: RED_POL
|
||||||
|
id: '8'
|
||||||
|
startStep: 12
|
||||||
|
endStep: 13
|
||||||
|
targetNodeId: '1'
|
||||||
|
initiator: DIRECT
|
||||||
|
type: OS
|
||||||
|
protocol: NA
|
||||||
|
state: GOOD
|
||||||
|
sourceNodeId: NA
|
||||||
|
sourceNodeService: NA
|
||||||
|
sourceNodeServiceState: NA
|
||||||
89
tests/config/one_node_states_on_off_main_config.yaml
Normal file
89
tests/config/one_node_states_on_off_main_config.yaml
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agentIdentifier: GENERIC
|
||||||
|
# Number of episodes to run per session
|
||||||
|
numEpisodes: 1
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
timeDelay: 1
|
||||||
|
# Filename of the scenario / laydown
|
||||||
|
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
sessionType: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
loadAgent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observationSpaceHighValue: 1000000000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
allOk: 0
|
||||||
|
# Node Operating State
|
||||||
|
offShouldBeOn: -10
|
||||||
|
offShouldBeResetting: -5
|
||||||
|
onShouldBeOff: -2
|
||||||
|
onShouldBeResetting: -5
|
||||||
|
resettingShouldBeOn: -5
|
||||||
|
resettingShouldBeOff: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node O/S or Service State
|
||||||
|
goodShouldBePatching: 2
|
||||||
|
goodShouldBeCompromised: 5
|
||||||
|
goodShouldBeOverwhelmed: 5
|
||||||
|
patchingShouldBeGood: -5
|
||||||
|
patchingShouldBeCompromised: 2
|
||||||
|
patchingShouldBeOverwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromisedShouldBeGood: -20
|
||||||
|
compromisedShouldBePatching: -20
|
||||||
|
compromisedShouldBeOverwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmedShouldBeGood: -20
|
||||||
|
overwhelmedShouldBePatching: -20
|
||||||
|
overwhelmedShouldBeCompromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
goodShouldBeRepairing: 2
|
||||||
|
goodShouldBeRestoring: 2
|
||||||
|
goodShouldBeCorrupt: 5
|
||||||
|
goodShouldBeDestroyed: 10
|
||||||
|
repairingShouldBeGood: -5
|
||||||
|
repairingShouldBeRestoring: 2
|
||||||
|
repairingShouldBeCorrupt: 2
|
||||||
|
repairingShouldBeDestroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoringShouldBeGood: -10
|
||||||
|
restoringShouldBeRepairing: -2
|
||||||
|
restoringShouldBeCorrupt: 1
|
||||||
|
restoringShouldBeDestroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corruptShouldBeGood: -10
|
||||||
|
corruptShouldBeRepairing: -10
|
||||||
|
corruptShouldBeRestoring: -10
|
||||||
|
corruptShouldBeDestroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyedShouldBeGood: -20
|
||||||
|
destroyedShouldBeRepairing: -20
|
||||||
|
destroyedShouldBeRestoring: -20
|
||||||
|
destroyedShouldBeCorrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
redIerRunning: -5
|
||||||
|
greenIerBlocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
osPatchingDuration: 5 # The time taken to patch the OS
|
||||||
|
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||||
|
servicePatchingDuration: 5 # The time taken to patch a service
|
||||||
|
fileSystemRepairingLimit: 5 # The time take to repair the file system
|
||||||
|
fileSystemRestoringLimit: 5 # The time take to restore the file system
|
||||||
|
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||||
@@ -1 +1,199 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from primaite.common.config_values_main import config_values_main
|
||||||
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
|
||||||
|
ACTION_SPACE_NODE_VALUES = 1
|
||||||
|
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _get_primaite_env_from_config(
|
||||||
|
main_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||||
|
):
|
||||||
|
"""Takes a config path and returns the created instance of Primaite."""
|
||||||
|
|
||||||
|
def load_config_values():
|
||||||
|
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||||
|
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||||
|
config_values.time_delay = int(config_data["timeDelay"])
|
||||||
|
config_values.config_filename_use_case = lay_down_config_path
|
||||||
|
config_values.session_type = config_data["sessionType"]
|
||||||
|
config_values.load_agent = bool(config_data["loadAgent"])
|
||||||
|
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||||
|
# Environment
|
||||||
|
config_values.observation_space_high_value = int(
|
||||||
|
config_data["observationSpaceHighValue"]
|
||||||
|
)
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
config_values.all_ok = int(config_data["allOk"])
|
||||||
|
# Node Operating State
|
||||||
|
config_values.off_should_be_on = int(config_data["offShouldBeOn"])
|
||||||
|
config_values.off_should_be_resetting = int(config_data["offShouldBeResetting"])
|
||||||
|
config_values.on_should_be_off = int(config_data["onShouldBeOff"])
|
||||||
|
config_values.on_should_be_resetting = int(config_data["onShouldBeResetting"])
|
||||||
|
config_values.resetting_should_be_on = int(config_data["resettingShouldBeOn"])
|
||||||
|
config_values.resetting_should_be_off = int(config_data["resettingShouldBeOff"])
|
||||||
|
config_values.resetting = int(config_data["resetting"])
|
||||||
|
# Node O/S or Service State
|
||||||
|
config_values.good_should_be_patching = int(config_data["goodShouldBePatching"])
|
||||||
|
config_values.good_should_be_compromised = int(
|
||||||
|
config_data["goodShouldBeCompromised"]
|
||||||
|
)
|
||||||
|
config_values.good_should_be_overwhelmed = int(
|
||||||
|
config_data["goodShouldBeOverwhelmed"]
|
||||||
|
)
|
||||||
|
config_values.patching_should_be_good = int(config_data["patchingShouldBeGood"])
|
||||||
|
config_values.patching_should_be_compromised = int(
|
||||||
|
config_data["patchingShouldBeCompromised"]
|
||||||
|
)
|
||||||
|
config_values.patching_should_be_overwhelmed = int(
|
||||||
|
config_data["patchingShouldBeOverwhelmed"]
|
||||||
|
)
|
||||||
|
config_values.patching = int(config_data["patching"])
|
||||||
|
config_values.compromised_should_be_good = int(
|
||||||
|
config_data["compromisedShouldBeGood"]
|
||||||
|
)
|
||||||
|
config_values.compromised_should_be_patching = int(
|
||||||
|
config_data["compromisedShouldBePatching"]
|
||||||
|
)
|
||||||
|
config_values.compromised_should_be_overwhelmed = int(
|
||||||
|
config_data["compromisedShouldBeOverwhelmed"]
|
||||||
|
)
|
||||||
|
config_values.compromised = int(config_data["compromised"])
|
||||||
|
config_values.overwhelmed_should_be_good = int(
|
||||||
|
config_data["overwhelmedShouldBeGood"]
|
||||||
|
)
|
||||||
|
config_values.overwhelmed_should_be_patching = int(
|
||||||
|
config_data["overwhelmedShouldBePatching"]
|
||||||
|
)
|
||||||
|
config_values.overwhelmed_should_be_compromised = int(
|
||||||
|
config_data["overwhelmedShouldBeCompromised"]
|
||||||
|
)
|
||||||
|
config_values.overwhelmed = int(config_data["overwhelmed"])
|
||||||
|
# Node File System State
|
||||||
|
config_values.good_should_be_repairing = int(
|
||||||
|
config_data["goodShouldBeRepairing"]
|
||||||
|
)
|
||||||
|
config_values.good_should_be_restoring = int(
|
||||||
|
config_data["goodShouldBeRestoring"]
|
||||||
|
)
|
||||||
|
config_values.good_should_be_corrupt = int(config_data["goodShouldBeCorrupt"])
|
||||||
|
config_values.good_should_be_destroyed = int(
|
||||||
|
config_data["goodShouldBeDestroyed"]
|
||||||
|
)
|
||||||
|
config_values.repairing_should_be_good = int(
|
||||||
|
config_data["repairingShouldBeGood"]
|
||||||
|
)
|
||||||
|
config_values.repairing_should_be_restoring = int(
|
||||||
|
config_data["repairingShouldBeRestoring"]
|
||||||
|
)
|
||||||
|
config_values.repairing_should_be_corrupt = int(
|
||||||
|
config_data["repairingShouldBeCorrupt"]
|
||||||
|
)
|
||||||
|
config_values.repairing_should_be_destroyed = int(
|
||||||
|
config_data["repairingShouldBeDestroyed"]
|
||||||
|
)
|
||||||
|
config_values.repairing = int(config_data["repairing"])
|
||||||
|
config_values.restoring_should_be_good = int(
|
||||||
|
config_data["restoringShouldBeGood"]
|
||||||
|
)
|
||||||
|
config_values.restoring_should_be_repairing = int(
|
||||||
|
config_data["restoringShouldBeRepairing"]
|
||||||
|
)
|
||||||
|
config_values.restoring_should_be_corrupt = int(
|
||||||
|
config_data["restoringShouldBeCorrupt"]
|
||||||
|
)
|
||||||
|
config_values.restoring_should_be_destroyed = int(
|
||||||
|
config_data["restoringShouldBeDestroyed"]
|
||||||
|
)
|
||||||
|
config_values.restoring = int(config_data["restoring"])
|
||||||
|
config_values.corrupt_should_be_good = int(config_data["corruptShouldBeGood"])
|
||||||
|
config_values.corrupt_should_be_repairing = int(
|
||||||
|
config_data["corruptShouldBeRepairing"]
|
||||||
|
)
|
||||||
|
config_values.corrupt_should_be_restoring = int(
|
||||||
|
config_data["corruptShouldBeRestoring"]
|
||||||
|
)
|
||||||
|
config_values.corrupt_should_be_destroyed = int(
|
||||||
|
config_data["corruptShouldBeDestroyed"]
|
||||||
|
)
|
||||||
|
config_values.corrupt = int(config_data["corrupt"])
|
||||||
|
config_values.destroyed_should_be_good = int(
|
||||||
|
config_data["destroyedShouldBeGood"]
|
||||||
|
)
|
||||||
|
config_values.destroyed_should_be_repairing = int(
|
||||||
|
config_data["destroyedShouldBeRepairing"]
|
||||||
|
)
|
||||||
|
config_values.destroyed_should_be_restoring = int(
|
||||||
|
config_data["destroyedShouldBeRestoring"]
|
||||||
|
)
|
||||||
|
config_values.destroyed_should_be_corrupt = int(
|
||||||
|
config_data["destroyedShouldBeCorrupt"]
|
||||||
|
)
|
||||||
|
config_values.destroyed = int(config_data["destroyed"])
|
||||||
|
config_values.scanning = int(config_data["scanning"])
|
||||||
|
# IER status
|
||||||
|
config_values.red_ier_running = int(config_data["redIerRunning"])
|
||||||
|
config_values.green_ier_blocked = int(config_data["greenIerBlocked"])
|
||||||
|
# Patching / Reset durations
|
||||||
|
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
|
||||||
|
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
|
||||||
|
config_values.service_patching_duration = int(
|
||||||
|
config_data["servicePatchingDuration"]
|
||||||
|
)
|
||||||
|
config_values.file_system_repairing_limit = int(
|
||||||
|
config_data["fileSystemRepairingLimit"]
|
||||||
|
)
|
||||||
|
config_values.file_system_restoring_limit = int(
|
||||||
|
config_data["fileSystemRestoringLimit"]
|
||||||
|
)
|
||||||
|
config_values.file_system_scanning_limit = int(
|
||||||
|
config_data["fileSystemScanningLimit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
config_file_main = open(main_config_path, "r")
|
||||||
|
config_data = yaml.safe_load(config_file_main)
|
||||||
|
# Create a config class
|
||||||
|
config_values = config_values_main()
|
||||||
|
# Load in config data
|
||||||
|
load_config_values()
|
||||||
|
env = Primaite(config_values, [])
|
||||||
|
config_values.num_steps = env.episode_steps
|
||||||
|
|
||||||
|
if env.config_values.agent_identifier == "GENERIC":
|
||||||
|
run_generic(env, config_values)
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
def run_generic(env, config_values):
|
||||||
|
"""Run against a generic agent."""
|
||||||
|
# Reset the environment at the start of the episode
|
||||||
|
# env.reset()
|
||||||
|
for episode in range(0, config_values.num_episodes):
|
||||||
|
for step in range(0, config_values.num_steps):
|
||||||
|
# Send the observation space to the agent to get an action
|
||||||
|
# TEMP - random action for now
|
||||||
|
# action = env.blue_agent_action(obs)
|
||||||
|
action = env.action_space.sample()
|
||||||
|
|
||||||
|
# Run the simulation step on the live environment
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
# Break if done is True
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Introduce a delay between steps
|
||||||
|
time.sleep(config_values.time_delay / 1000)
|
||||||
|
|
||||||
|
# Reset the environment at the end of the episode
|
||||||
|
# env.reset()
|
||||||
|
|
||||||
|
# env.close()
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
from tests import TEST_CONFIG_ROOT
|
||||||
|
from tests.conftest import _get_primaite_env_from_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_rewards_are_being_penalised_at_each_step_function():
|
||||||
|
"""
|
||||||
|
Test that hardware state is penalised at each step.
|
||||||
|
|
||||||
|
When the initial state is OFF compared to reference state which is ON.
|
||||||
|
"""
|
||||||
|
env = _get_primaite_env_from_config(
|
||||||
|
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT
|
||||||
|
/ "one_node_states_on_off_lay_down_config.yaml",
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
|
||||||
|
File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3)
|
||||||
|
Hardware State: onShouldBeOff = -2 (between Steps 4 & 6)
|
||||||
|
Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9)
|
||||||
|
Operating System State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12)
|
||||||
|
|
||||||
|
Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
|
||||||
|
Step Count: 13
|
||||||
|
|
||||||
|
For the 4 steps where this occurs the average reward is:
|
||||||
|
Average Reward: 2 (26 / 13)
|
||||||
|
"""
|
||||||
|
print("average reward", env.average_reward)
|
||||||
|
assert env.average_reward == 2.0
|
||||||
|
|||||||
Reference in New Issue
Block a user