Merge branch 'dev' into feature/build-pipeline-precommit
This commit is contained in:
@@ -107,6 +107,7 @@ class Primaite(Env):
|
|||||||
|
|
||||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||||
self.green_iers: Dict[str, IER] = {}
|
self.green_iers: Dict[str, IER] = {}
|
||||||
|
self.green_iers_reference: Dict[str, IER] = {}
|
||||||
|
|
||||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||||
self.node_pol = {}
|
self.node_pol = {}
|
||||||
@@ -309,6 +310,9 @@ class Primaite(Env):
|
|||||||
for link_key, link_value in self.links.items():
|
for link_key, link_value in self.links.items():
|
||||||
link_value.clear_traffic()
|
link_value.clear_traffic()
|
||||||
|
|
||||||
|
for link in self.links_reference.values():
|
||||||
|
link.clear_traffic()
|
||||||
|
|
||||||
# Create a Transaction (metric) object for this step
|
# Create a Transaction (metric) object for this step
|
||||||
transaction = Transaction(
|
transaction = Transaction(
|
||||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||||
@@ -346,7 +350,7 @@ class Primaite(Env):
|
|||||||
self.network_reference,
|
self.network_reference,
|
||||||
self.nodes_reference,
|
self.nodes_reference,
|
||||||
self.links_reference,
|
self.links_reference,
|
||||||
self.green_iers,
|
self.green_iers_reference,
|
||||||
self.acl,
|
self.acl,
|
||||||
self.step_count,
|
self.step_count,
|
||||||
) # Network PoL
|
) # Network PoL
|
||||||
@@ -373,6 +377,7 @@ class Primaite(Env):
|
|||||||
self.nodes_post_red,
|
self.nodes_post_red,
|
||||||
self.nodes_reference,
|
self.nodes_reference,
|
||||||
self.green_iers,
|
self.green_iers,
|
||||||
|
self.green_iers_reference,
|
||||||
self.red_iers,
|
self.red_iers,
|
||||||
self.step_count,
|
self.step_count,
|
||||||
self.training_config,
|
self.training_config,
|
||||||
@@ -864,6 +869,17 @@ class Primaite(Env):
|
|||||||
ier_destination,
|
ier_destination,
|
||||||
ier_mission_criticality,
|
ier_mission_criticality,
|
||||||
)
|
)
|
||||||
|
self.green_iers_reference[ier_id] = IER(
|
||||||
|
ier_id,
|
||||||
|
ier_start_step,
|
||||||
|
ier_end_step,
|
||||||
|
ier_load,
|
||||||
|
ier_protocol,
|
||||||
|
ier_port,
|
||||||
|
ier_source,
|
||||||
|
ier_destination,
|
||||||
|
ier_mission_criticality,
|
||||||
|
)
|
||||||
|
|
||||||
def create_red_ier(self, item):
|
def create_red_ier(self, item):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,17 +2,21 @@
|
|||||||
"""Implements reward function."""
|
"""Implements reward function."""
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from primaite import getLogger
|
||||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||||
from primaite.common.service import Service
|
from primaite.common.service import Service
|
||||||
from primaite.nodes.active_node import ActiveNode
|
from primaite.nodes.active_node import ActiveNode
|
||||||
from primaite.nodes.service_node import ServiceNode
|
from primaite.nodes.service_node import ServiceNode
|
||||||
|
|
||||||
|
_LOGGER = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def calculate_reward_function(
|
def calculate_reward_function(
|
||||||
initial_nodes,
|
initial_nodes,
|
||||||
final_nodes,
|
final_nodes,
|
||||||
reference_nodes,
|
reference_nodes,
|
||||||
green_iers,
|
green_iers,
|
||||||
|
green_iers_reference,
|
||||||
red_iers,
|
red_iers,
|
||||||
step_count,
|
step_count,
|
||||||
config_values,
|
config_values,
|
||||||
@@ -68,14 +72,36 @@ def calculate_reward_function(
|
|||||||
reward_value += config_values.red_ier_running
|
reward_value += config_values.red_ier_running
|
||||||
|
|
||||||
# Go through each green IER - penalise if it's not running (weighted)
|
# Go through each green IER - penalise if it's not running (weighted)
|
||||||
|
# but only if it's supposed to be running (it's running in reference)
|
||||||
for ier_key, ier_value in green_iers.items():
|
for ier_key, ier_value in green_iers.items():
|
||||||
|
reference_ier = green_iers_reference[ier_key]
|
||||||
start_step = ier_value.get_start_step()
|
start_step = ier_value.get_start_step()
|
||||||
stop_step = ier_value.get_end_step()
|
stop_step = ier_value.get_end_step()
|
||||||
if step_count >= start_step and step_count <= stop_step:
|
if step_count >= start_step and step_count <= stop_step:
|
||||||
if not ier_value.get_is_running():
|
reference_blocked = reference_ier.get_is_running()
|
||||||
reward_value += (
|
live_blocked = ier_value.get_is_running()
|
||||||
config_values.green_ier_blocked
|
ier_reward = (
|
||||||
* ier_value.get_mission_criticality()
|
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||||
|
)
|
||||||
|
|
||||||
|
if live_blocked and not reference_blocked:
|
||||||
|
_LOGGER.debug(
|
||||||
|
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||||
|
)
|
||||||
|
reward_value += ier_reward
|
||||||
|
elif live_blocked and reference_blocked:
|
||||||
|
_LOGGER.debug(
|
||||||
|
(
|
||||||
|
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||||
|
f"Penalty of {ier_reward} was NOT applied."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif not live_blocked and reference_blocked:
|
||||||
|
_LOGGER.debug(
|
||||||
|
(
|
||||||
|
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||||
|
f"Penalty of {ier_reward} was NOT applied."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return reward_value
|
return reward_value
|
||||||
|
|||||||
Reference in New Issue
Block a user