Merge branch 'dev' into feature/build-pipeline-precommit

This commit is contained in:
Marek Wolan
2023-06-27 15:49:49 +01:00
2 changed files with 47 additions and 5 deletions

View File

@@ -107,6 +107,7 @@ class Primaite(Env):
# Create a dictionary to hold all the green IERs (this will come from an external source) # Create a dictionary to hold all the green IERs (this will come from an external source)
self.green_iers: Dict[str, IER] = {} self.green_iers: Dict[str, IER] = {}
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source) # Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {} self.node_pol = {}
@@ -309,6 +310,9 @@ class Primaite(Env):
for link_key, link_value in self.links.items(): for link_key, link_value in self.links.items():
link_value.clear_traffic() link_value.clear_traffic()
for link in self.links_reference.values():
link.clear_traffic()
# Create a Transaction (metric) object for this step # Create a Transaction (metric) object for this step
transaction = Transaction( transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count datetime.now(), self.agent_identifier, self.episode_count, self.step_count
@@ -346,7 +350,7 @@ class Primaite(Env):
self.network_reference, self.network_reference,
self.nodes_reference, self.nodes_reference,
self.links_reference, self.links_reference,
self.green_iers, self.green_iers_reference,
self.acl, self.acl,
self.step_count, self.step_count,
) # Network PoL ) # Network PoL
@@ -373,6 +377,7 @@ class Primaite(Env):
self.nodes_post_red, self.nodes_post_red,
self.nodes_reference, self.nodes_reference,
self.green_iers, self.green_iers,
self.green_iers_reference,
self.red_iers, self.red_iers,
self.step_count, self.step_count,
self.training_config, self.training_config,
@@ -864,6 +869,17 @@ class Primaite(Env):
ier_destination, ier_destination,
ier_mission_criticality, ier_mission_criticality,
) )
self.green_iers_reference[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_red_ier(self, item): def create_red_ier(self, item):
""" """

View File

@@ -2,17 +2,21 @@
"""Implements reward function.""" """Implements reward function."""
from typing import Dict from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
def calculate_reward_function( def calculate_reward_function(
initial_nodes, initial_nodes,
final_nodes, final_nodes,
reference_nodes, reference_nodes,
green_iers, green_iers,
green_iers_reference,
red_iers, red_iers,
step_count, step_count,
config_values, config_values,
@@ -68,14 +72,36 @@ def calculate_reward_function(
reward_value += config_values.red_ier_running reward_value += config_values.red_ier_running
# Go through each green IER and penalise it (weighted by mission criticality)
# when it is blocked in the live environment, but only if it is supposed to be
# running, i.e. it is running in the reference environment at this step.
for ier_key, ier_value in green_iers.items():
    # Same IER in the reference environment.
    # NOTE(review): assumes every live IER id also exists in
    # green_iers_reference; a missing id raises KeyError - confirm with caller.
    reference_ier = green_iers_reference[ier_key]
    start_step = ier_value.get_start_step()
    stop_step = ier_value.get_end_step()
    # Only score the IER while it is scheduled to be active (inclusive bounds).
    if start_step <= step_count <= stop_step:
        # An IER is "blocked" when it is NOT running. The flags were previously
        # assigned get_is_running() directly, which inverted their meaning and
        # penalised the agent for IERs that were running normally.
        reference_blocked = not reference_ier.get_is_running()
        live_blocked = not ier_value.get_is_running()
        ier_reward = (
            config_values.green_ier_blocked * ier_value.get_mission_criticality()
        )

        if live_blocked and not reference_blocked:
            # Blocked only in the live environment: apply the penalty.
            _LOGGER.debug(
                f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
            )
            reward_value += ier_reward
        elif live_blocked and reference_blocked:
            # Blocked in both environments: no penalty is applied.
            _LOGGER.debug(
                (
                    f"IER {ier_key} is blocked in the reference and live environments. "
                    f"Penalty of {ier_reward} was NOT applied."
                )
            )
        elif not live_blocked and reference_blocked:
            # Running live although blocked in the reference: no penalty.
            _LOGGER.debug(
                (
                    f"IER {ier_key} is blocked in the reference env but not in the live one. "
                    f"Penalty of {ier_reward} was NOT applied."
                )
            )
return reward_value return reward_value