diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..1307a930 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -109,6 +109,7 @@ class Primaite(Env): # Create a dictionary to hold all the green IERs (this will come from an external source) self.green_iers: Dict[str, IER] = {} + self.green_iers_reference: Dict[str, IER] = {} # Create a dictionary to hold all the node PoLs (this will come from an external source) self.node_pol = {} @@ -310,6 +311,9 @@ class Primaite(Env): # Need to clear traffic on all links first for link_key, link_value in self.links.items(): link_value.clear_traffic() + + for link in self.links_reference.values(): + link.clear_traffic() # Create a Transaction (metric) object for this step transaction = Transaction( @@ -348,7 +352,7 @@ class Primaite(Env): self.network_reference, self.nodes_reference, self.links_reference, - self.green_iers, + self.green_iers_reference, self.acl, self.step_count, ) # Network PoL @@ -375,6 +379,7 @@ class Primaite(Env): self.nodes_post_red, self.nodes_reference, self.green_iers, + self.green_iers_reference, self.red_iers, self.step_count, self.training_config, @@ -866,6 +871,17 @@ class Primaite(Env): ier_destination, ier_mission_criticality, ) + self.green_iers_reference[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source, + ier_destination, + ier_mission_criticality, + ) def create_red_ier(self, item): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index a620f9b3..777dcf74 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -6,6 +6,9 @@ from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +from primaite import getLogger + +_LOGGER = getLogger(__name__) def calculate_reward_function( @@ -13,6 +16,7 @@ def calculate_reward_function( final_nodes, reference_nodes, green_iers, + green_iers_reference, red_iers, step_count, config_values, @@ -68,11 +72,15 @@ def calculate_reward_function( reward_value += config_values.red_ier_running # Go through each green IER - penalise if it's not running (weighted) + # but only if it's supposed to be running (it's running in reference) for ier_key, ier_value in green_iers.items(): + ref_ier = green_iers_reference[ier_key] start_step = ier_value.get_start_step() stop_step = ier_value.get_end_step() if step_count >= start_step and step_count <= stop_step: - if not ier_value.get_is_running(): + if not ier_value.get_is_running() and ref_ier.get_is_running(): + # what should happen if reference IER is blocked but live IER is running? + _LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked") reward_value += ( config_values.green_ier_blocked * ier_value.get_mission_criticality()