apply pre-commits

This commit is contained in:
Marek Wolan
2023-06-27 11:20:18 +01:00
parent feead2cd44
commit e2d6abf833
2 changed files with 19 additions and 15 deletions

View File

@@ -14,8 +14,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -24,8 +23,9 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState, ObservationType,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
@@ -178,7 +176,6 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -200,7 +197,7 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time
# now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -239,7 +236,9 @@ class Primaite(Env):
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
_LOGGER.info(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -311,7 +310,7 @@ class Primaite(Env):
# Need to clear traffic on all links first
for link_key, link_value in self.links.items():
link_value.clear_traffic()
for link in self.links_reference.values():
link.clear_traffic()
@@ -384,7 +383,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
#print(f" Step {self.step_count} Reward: {str(reward)}")
# print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -1049,7 +1048,6 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""
Extracts action_info.

View File

@@ -2,11 +2,11 @@
"""Implements reward function."""
from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite import getLogger
_LOGGER = getLogger(__name__)
@@ -80,7 +80,13 @@ def calculate_reward_function(
if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running() and ref_ier.get_is_running():
# what should happen if reference IER is blocked but live IER is running?
_LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked")
_LOGGER.debug(
(
f"Applying penalty of "
f"{config_values.green_ier_blocked * ier_value.get_mission_criticality()} "
f"due to IER {ier_key} being blocked"
)
)
reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()