diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 1307a930..bdfe00dd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -178,7 +176,6 @@ class Primaite(Env): # It will be initialised later. self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -200,7 +197,7 @@ class Primaite(Env): try: plt.tight_layout() nx.draw_networkx(self.network, with_labels=True) - now = datetime.now() # current date and time + # now = datetime.now() # current date and time file_path = session_path / f"network_{timestamp_str}.png" plt.savefig(file_path, format="PNG") @@ -239,7 +236,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -311,7 +310,7 @@ class Primaite(Env): # Need to clear traffic on all links first for link_key, link_value in self.links.items(): link_value.clear_traffic() - + for link in self.links_reference.values(): link.clear_traffic() @@ -384,7 +383,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + # print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1049,7 +1048,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 777dcf74..f48db259 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -2,11 +2,11 @@ """Implements reward function.""" from typing import Dict +from primaite import getLogger from primaite.common.enums import FileSystemState, HardwareState, SoftwareState from primaite.common.service import Service from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode -from primaite import getLogger _LOGGER = getLogger(__name__) @@ -80,7 +80,13 @@ def calculate_reward_function( if step_count >= start_step and step_count <= stop_step: if not ier_value.get_is_running() and ref_ier.get_is_running(): # what should happen if reference IER is blocked but live IER is running? - _LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked") + _LOGGER.debug( + ( + f"Applying penalty of " + f"{config_values.green_ier_blocked * ier_value.get_mission_criticality()} " + f"due to IER {ier_key} being blocked" + ) + ) reward_value += ( config_values.green_ier_blocked * ier_value.get_mission_criticality()