Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test

This commit is contained in:
Czar Echavez
2023-06-27 14:16:10 +01:00
2 changed files with 48 additions and 5 deletions

View File

@@ -107,6 +107,7 @@ class Primaite(Env):
# Create a dictionary to hold all the green IERs (this will come from an external source)
self.green_iers: Dict[str, IER] = {}
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {}
@@ -196,6 +197,7 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
# now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -315,6 +317,9 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
link_value.clear_traffic()
for link in self.links_reference.values():
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
@@ -352,7 +357,7 @@ class Primaite(Env):
self.network_reference,
self.nodes_reference,
self.links_reference,
self.green_iers,
self.green_iers_reference,
self.acl,
self.step_count,
) # Network PoL
@@ -379,6 +384,7 @@ class Primaite(Env):
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.green_iers_reference,
self.red_iers,
self.step_count,
self.training_config,
@@ -874,6 +880,17 @@ class Primaite(Env):
ier_destination,
ier_mission_criticality,
)
self.green_iers_reference[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_red_ier(self, item):
"""

View File

@@ -2,17 +2,21 @@
"""Implements reward function."""
from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
@@ -68,14 +72,36 @@ def calculate_reward_function(
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
# but only if it's supposed to be running (it's running in reference)
for ier_key, ier_value in green_iers.items():
reference_ier = green_iers_reference[ier_key]
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running():
reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
reference_blocked = reference_ier.get_is_running()
live_blocked = ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
)
if live_blocked and not reference_blocked:
_LOGGER.debug(
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
)
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference and live environments. "
f"Penalty of {ier_reward} was NOT applied."
)
)
elif not live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference env but not in the live one. "
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value