Fix reference IERs
This commit is contained in:
@@ -109,6 +109,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||
self.green_iers: Dict[str, IER] = {}
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
@@ -310,6 +311,9 @@ class Primaite(Env):
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
@@ -348,7 +352,7 @@ class Primaite(Env):
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
@@ -375,6 +379,7 @@ class Primaite(Env):
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
@@ -866,6 +871,17 @@ class Primaite(Env):
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
self.green_iers_reference[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,9 @@ from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
@@ -13,6 +16,7 @@ def calculate_reward_function(
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
@@ -68,11 +72,15 @@ def calculate_reward_function(
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
ref_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
if not ier_value.get_is_running() and ref_ier.get_is_running():
|
||||
# what should happen if reference IER is blocked but live IER is running?
|
||||
_LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked")
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
|
||||
Reference in New Issue
Block a user