Remove redundant 'if TYPE_CHECKING' statements

This commit is contained in:
Marek Wolan
2023-07-18 10:21:06 +01:00
parent 393505b98b
commit a2ef4328dd
15 changed files with 37 additions and 65 deletions

View File

@@ -4,8 +4,9 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, Optional, Union
from uuid import uuid4
import primaite
@@ -16,10 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
from primaite.environment.primaite_env import Primaite
from primaite.utils.session_metadata_parser import parse_session_metadata
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_session_path(session_timestamp: datetime) -> Path:

View File

@@ -4,8 +4,9 @@ from __future__ import annotations
import json
import shutil
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
@@ -19,10 +20,7 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
# TODO: verify type of env_config

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
import json
from logging import Logger
from pathlib import Path
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -14,10 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class SB3Agent(AgentSessionABC):

View File

@@ -1,12 +1,10 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
import numpy as np
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
if TYPE_CHECKING:
import numpy as np
class RandomAgent(HardCodedAgentSessionABC):
"""
@@ -15,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
Get a completely random action from the action space.
"""
def _calculate_action(self, obs: "np.ndarray") -> int:
def _calculate_action(self, obs: np.ndarray) -> int:
return self._env.action_space.sample()
@@ -26,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
All action spaces setup so dummy action is always 0 regardless of action type used.
"""
def _calculate_action(self, obs: "np.ndarray") -> int:
def _calculate_action(self, obs: np.ndarray) -> int:
return 0
@@ -37,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
A valid ACL action that has no effect; does nothing.
"""
def _calculate_action(self, obs: "np.ndarray") -> int:
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
nothing_action = transform_action_acl_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)
@@ -52,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
A valid Node action that has no effect; does nothing.
"""
def _calculate_action(self, obs: "np.ndarray") -> int:
def _calculate_action(self, obs: np.ndarray) -> int:
nothing_action = [1, "NONE", "ON", 0]
nothing_action = transform_action_node_enum(nothing_action)
nothing_action = get_new_action(nothing_action, self._env.action_dict)

View File

@@ -1,15 +1,13 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, TYPE_CHECKING, Union
from typing import Any, Dict, Final, Union
import yaml
from primaite import getLogger, USERS_CONFIG_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
from dataclasses import dataclass, field
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, Final, Optional, Union
import yaml
@@ -18,9 +19,6 @@ from primaite.common.enums import (
SessionType,
)
if TYPE_CHECKING:
from logging import Logger
_LOGGER: Logger = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"

View File

@@ -2,6 +2,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from logging import Logger
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
import numpy as np
@@ -15,12 +16,10 @@ from primaite.nodes.service_node import ServiceNode
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from logging import Logger
from primaite.environment.primaite_env import Primaite
_LOGGER: "Logger" = logging.getLogger(__name__)
_LOGGER: Logger = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):

View File

@@ -3,9 +3,10 @@
import copy
import logging
import uuid as uuid
from logging import Logger
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Final, List, Tuple, Union
import networkx as nx
import numpy as np
@@ -49,10 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
class Primaite(Env):

View File

@@ -1,5 +1,6 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from logging import Logger
from typing import Dict, TYPE_CHECKING, Union
from primaite import getLogger
@@ -10,12 +11,10 @@ from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
if TYPE_CHECKING:
from logging import Logger
from primaite.config.training_config import TrainingConfig
from primaite.pol.ier import IER
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def calculate_reward_function(

View File

@@ -5,14 +5,11 @@ import importlib.util
import os
import subprocess
import sys
from typing import TYPE_CHECKING
from logging import Logger
from primaite import getLogger, NOTEBOOKS_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def start_jupyter_session() -> None:

View File

@@ -6,7 +6,7 @@ from primaite import getLogger
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run() -> None:

View File

@@ -2,17 +2,14 @@
import filecmp
import os
import shutil
from logging import Logger
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger, NOTEBOOKS_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run(overwrite_existing: bool = True) -> None:

View File

@@ -12,7 +12,7 @@ from primaite import getLogger, USERS_CONFIG_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run(overwrite_existing: bool = True) -> None:

View File

@@ -1,12 +1,9 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
from typing import TYPE_CHECKING
from logging import Logger
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def run() -> None:

View File

@@ -1,16 +1,13 @@
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
import os
from logging import Logger
from pathlib import Path
from typing import TYPE_CHECKING
import pkg_resources
from primaite import getLogger
if TYPE_CHECKING:
from logging import Logger
_LOGGER: "Logger" = getLogger(__name__)
_LOGGER: Logger = getLogger(__name__)
def get_file_path(path: str) -> Path: