diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 7608e7db..be4cc434 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -107,6 +107,7 @@ class Primaite(Env): # Create a dictionary to hold all the green IERs (this will come from an external source) self.green_iers: Dict[str, IER] = {} + self.green_iers_reference: Dict[str, IER] = {} # Create a dictionary to hold all the node PoLs (this will come from an external source) self.node_pol = {} @@ -309,6 +310,9 @@ class Primaite(Env): for link_key, link_value in self.links.items(): link_value.clear_traffic() + for link in self.links_reference.values(): + link.clear_traffic() + # Create a Transaction (metric) object for this step transaction = Transaction( datetime.now(), self.agent_identifier, self.episode_count, self.step_count @@ -346,7 +350,7 @@ class Primaite(Env): self.network_reference, self.nodes_reference, self.links_reference, - self.green_iers, + self.green_iers_reference, self.acl, self.step_count, ) # Network PoL @@ -373,6 +377,7 @@ class Primaite(Env): self.nodes_post_red, self.nodes_reference, self.green_iers, + self.green_iers_reference, self.red_iers, self.step_count, self.training_config, @@ -864,6 +869,17 @@ class Primaite(Env): ier_destination, ier_mission_criticality, ) + self.green_iers_reference[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source, + ier_destination, + ier_mission_criticality, + ) def create_red_ier(self, item): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index a620f9b3..aa9e4503 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -2,17 +2,21 @@ """Implements reward function.""" from typing import Dict +from primaite import getLogger from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +_LOGGER = getLogger(__name__) + def calculate_reward_function( initial_nodes, final_nodes, reference_nodes, green_iers, + green_iers_reference, red_iers, step_count, config_values, @@ -68,14 +72,36 @@ def calculate_reward_function( reward_value += config_values.red_ier_running # Go through each green IER - penalise if it's not running (weighted) + # but only if it's supposed to be running (it's running in reference) for ier_key, ier_value in green_iers.items(): + reference_ier = green_iers_reference[ier_key] start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running(): - reward_value += ( - config_values.green_ier_blocked - * ier_value.get_mission_criticality() + reference_blocked = reference_ier.get_is_running() + live_blocked = ier_value.get_is_running() + ier_reward = ( + config_values.green_ier_blocked * ier_value.get_mission_criticality() + ) + + if live_blocked and not reference_blocked: + _LOGGER.debug( + f"Applying reward of {ier_reward} because IER {ier_key} is blocked" + ) + reward_value += ier_reward + elif live_blocked and reference_blocked: + _LOGGER.debug( + ( + f"IER {ier_key} is blocked in the reference and live environments. " + f"Penalty of {ier_reward} was NOT applied." + ) + ) + elif not live_blocked and reference_blocked: + _LOGGER.debug( + ( + f"IER {ier_key} is blocked in the reference env but not in the live one. " + f"Penalty of {ier_reward} was NOT applied." + ) ) return reward_value