Merge remote-tracking branch 'origin/bugfix/1554-fix-not-learning-iers' into feature/917_Integrate_with_RLLib
This commit is contained in:
@@ -15,8 +15,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -25,8 +24,9 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState, ObservationType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -36,15 +36,13 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
@@ -114,6 +112,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||
self.green_iers: Dict[str, IER] = {}
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
@@ -182,7 +181,6 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -204,7 +202,7 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
# now = datetime.now() # current date and time
|
||||
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
@@ -243,7 +241,9 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
_LOGGER.info(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
@@ -318,6 +318,9 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
@@ -355,7 +358,7 @@ class Primaite(Env):
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
@@ -382,11 +385,12 @@ class Primaite(Env):
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -881,6 +885,17 @@ class Primaite(Env):
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
self.green_iers_reference[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
@@ -1048,7 +1063,6 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
@@ -2,17 +2,21 @@
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
@@ -68,14 +72,36 @@ def calculate_reward_function(
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
_LOGGER.debug(
|
||||
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||
)
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
Reference in New Issue
Block a user