Change typehints after mypy analysis

This commit is contained in:
Marek Wolan
2023-07-17 16:22:07 +01:00
parent 432da5ca90
commit bfce2f9a7b
6 changed files with 14 additions and 12 deletions

View File

@@ -175,6 +175,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if protocol != "ANY":
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
# TODO: This should throw an error because protocol is a string
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules

View File

@@ -101,6 +101,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
property_action,
action_service_index,
]
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
action = transform_action_node_enum(action, action_dict)
action = get_new_action(action, action_dict)
# We can only perform 1 action on each step

View File

@@ -39,6 +39,7 @@ class SB3Agent(AgentSessionABC):
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_class: Union[PPO, A2C]
if self._training_config.agent_identifier == AgentIdentifier.PPO:
self._agent_class = PPO
elif self._training_config.agent_identifier == AgentIdentifier.A2C:

View File

@@ -1,8 +1,8 @@
from typing import TypeVar
from typing import Union
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode)
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""

View File

@@ -5,7 +5,7 @@ import logging
import uuid as uuid
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union
import networkx as nx
import numpy as np
@@ -118,8 +118,7 @@ class Primaite(Env):
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
# TODO: figure out type
self.node_pol = {}
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
self.red_iers: Dict[str, IER] = {}
@@ -149,8 +148,7 @@ class Primaite(Env):
"""The total number of time steps completed."""
# Create step info dictionary
# TODO: figure out type
self.step_info = {}
self.step_info: Dict[Any] = {}
# Total reward
self.total_reward: float = 0
@@ -315,7 +313,7 @@ class Primaite(Env):
return self.env_obs
def step(self, action: int) -> tuple(np.ndarray, float, bool, Dict):
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""
AI Gym Step function.

View File

@@ -4,6 +4,7 @@ from typing import Dict
from networkx import MultiGraph, shortest_path
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
@@ -13,6 +14,8 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
_LOGGER = getLogger(__name__)
_VERBOSE: bool = False
@@ -270,8 +273,7 @@ def apply_red_agent_node_pol(
# Do nothing, service not on this node
pass
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - misconfiguration")
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
# Only apply the PoL if the checks have passed (based on the initiator type)
if passed_checks:
@@ -292,8 +294,7 @@ def apply_red_agent_node_pol(
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print("Node Red Agent PoL not allowed - did not pass checks")
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass