Change typehints after mypy analysis
This commit is contained in:
@@ -175,6 +175,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
# TODO: This should throw an error because protocol is a string
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
@@ -101,6 +101,7 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
|
||||
@@ -39,6 +39,7 @@ class SB3Agent(AgentSessionABC):
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_class: Union[PPO, A2C]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion = TypeVar("NodeUnion", ServiceNode, ActiveNode, PassiveNode)
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import uuid as uuid
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -118,8 +118,7 @@ class Primaite(Env):
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
# TODO: figure out type
|
||||
self.node_pol = {}
|
||||
self.node_pol: Dict[str, NodeStateInstructionGreen] = {}
|
||||
|
||||
# Create a dictionary to hold all the red agent IERs (this will come from an external source)
|
||||
self.red_iers: Dict[str, IER] = {}
|
||||
@@ -149,8 +148,7 @@ class Primaite(Env):
|
||||
"""The total number of time steps completed."""
|
||||
|
||||
# Create step info dictionary
|
||||
# TODO: figure out type
|
||||
self.step_info = {}
|
||||
self.step_info: Dict[Any] = {}
|
||||
|
||||
# Total reward
|
||||
self.total_reward: float = 0
|
||||
@@ -315,7 +313,7 @@ class Primaite(Env):
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action: int) -> tuple(np.ndarray, float, bool, Dict):
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
|
||||
"""
|
||||
AI Gym Step function.
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
@@ -13,6 +14,8 @@ from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
@@ -270,8 +273,7 @@ def apply_red_agent_node_pol(
|
||||
# Do nothing, service not on this node
|
||||
pass
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - misconfiguration")
|
||||
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
|
||||
|
||||
# Only apply the PoL if the checks have passed (based on the initiator type)
|
||||
if passed_checks:
|
||||
@@ -292,8 +294,7 @@ def apply_red_agent_node_pol(
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user