From 4032f3a2a8eab06b4bbef07267fd7d9d15b9e845 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 17 Jul 2023 16:22:07 +0100 Subject: [PATCH] Change typehints after mypy analysis --- src/primaite/agents/hardcoded_acl.py | 1 + src/primaite/agents/hardcoded_node.py | 1 + src/primaite/agents/sb3.py | 1 + src/primaite/common/custom_typing.py | 4 ++-- src/primaite/environment/primaite_env.py | 10 ++++------ src/primaite/pol/red_agent_pol.py | 9 +++++---- 6 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 98c1d7d9..0ac5022c 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -175,6 +175,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if protocol != "ANY": protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services + # TODO: This should throw an error because protocol is a string matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index c00cf421..b74c3a0b 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -101,6 +101,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): property_action, action_service_index, ] + # TODO: transform_action_node_enum takes only one argument, not sure why two are given here. action = transform_action_node_enum(action, action_dict) action = get_new_action(action, action_dict) # We can only perform 1 action on each step diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 5f04acc0..462360a0 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -39,6 +39,7 @@ class SB3Agent(AgentSessionABC): msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) + self._agent_class: Union[PPO, A2C] if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_class = PPO elif self._training_config.agent_identifier == AgentIdentifier.A2C: diff --git a/src/primaite/common/custom_typing.py b/src/primaite/common/custom_typing.py index e01c8713..4130e71a 100644 --- a/src/primaite/common/custom_typing.py +++ b/src/primaite/common/custom_typing.py @@ -1,8 +1,8 @@ -from typing import TypeVar +from typing import Union from primaite.nodes.active_node import ActiveNode from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode -NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode) +NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode] """A Union of ActiveNode, PassiveNode, and ServiceNode.""" diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index d1c8adf5..f78b5f8d 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,7 +5,7 @@ import logging import uuid as uuid from pathlib import Path from random import choice, randint, sample, uniform -from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union import networkx as nx import numpy as np @@ -118,8 +118,7 @@ class Primaite(Env): self.green_iers_reference: Dict[str, IER] = {} # Create a dictionary to hold all the node PoLs (this will come from an external source) - # TODO: figure out type - self.node_pol = {} + self.node_pol: Dict[str, NodeStateInstructionGreen] = {} # Create a dictionary to hold all the red agent IERs (this will come from an external source) self.red_iers: Dict[str, IER] = {} @@ -149,8 +148,7 @@ class Primaite(Env): """The total number of time steps completed.""" # Create step info dictionary - # TODO: figure out type - self.step_info = {} + self.step_info: Dict[Any] = {} # Total reward self.total_reward: float = 0 @@ -315,7 +313,7 @@ class Primaite(Env): return self.env_obs - def step(self, action: int) -> tuple(np.ndarray, float, bool, Dict): + def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]: """ AI Gym Step function. diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index c9f75850..2801e8b0 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -4,6 +4,7 @@ from typing import Dict from networkx import MultiGraph, shortest_path +from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState @@ -13,6 +14,8 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER +_LOGGER = getLogger(__name__) + _VERBOSE: bool = False @@ -270,8 +273,7 @@ def apply_red_agent_node_pol( # Do nothing, service not on this node pass else: - if _VERBOSE: - print("Node Red Agent PoL not allowed - misconfiguration") + _LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration") # Only apply the PoL if the checks have passed (based on the initiator type) if passed_checks: @@ -292,8 +294,7 @@ def apply_red_agent_node_pol( if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.set_file_system_state(state) else: - if _VERBOSE: - print("Node Red Agent PoL not allowed - did not pass checks") + _LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks") else: # PoL is not valid in this time step pass