Merge branch 'dev' into feature/build-pipeline-precommit
This commit is contained in:
@@ -107,6 +107,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||
self.green_iers: Dict[str, IER] = {}
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
@@ -309,6 +310,9 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
@@ -346,7 +350,7 @@ class Primaite(Env):
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
@@ -373,6 +377,7 @@ class Primaite(Env):
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
@@ -864,6 +869,17 @@ class Primaite(Env):
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
self.green_iers_reference[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
|
||||
@@ -2,17 +2,21 @@
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
@@ -68,14 +72,36 @@ def calculate_reward_function(
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
reference_blocked = reference_ier.get_is_running()
|
||||
live_blocked = ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
_LOGGER.debug(
|
||||
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||
)
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
Reference in New Issue
Block a user