apply pre-commits
This commit is contained in:
@@ -14,8 +14,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -24,8 +23,9 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState, ObservationType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -178,7 +176,6 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -200,7 +197,7 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
# now = datetime.now() # current date and time
|
||||
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
@@ -239,7 +236,9 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
_LOGGER.info(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
@@ -311,7 +310,7 @@ class Primaite(Env):
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
@@ -384,7 +383,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -1049,7 +1048,6 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -80,7 +80,13 @@ def calculate_reward_function(
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running() and ref_ier.get_is_running():
|
||||
# what should happen if reference IER is blocked but live IER is running?
|
||||
_LOGGER.debug(f"Applying penalty of {config_values.green_ier_blocked * ier_value.get_mission_criticality()} due to IER {ier_key} being blocked")
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"Applying penalty of "
|
||||
f"{config_values.green_ier_blocked * ier_value.get_mission_criticality()} "
|
||||
f"due to IER {ier_key} being blocked"
|
||||
)
|
||||
)
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
|
||||
Reference in New Issue
Block a user