From a2ef4328dd43a672ddb3f53d5604053e1a0c48ea Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 18 Jul 2023 10:21:06 +0100 Subject: [PATCH] Remove redundant 'if TYPE_CHECKING' statements --- src/primaite/agents/agent_abc.py | 8 +++----- src/primaite/agents/rllib.py | 8 +++----- src/primaite/agents/sb3.py | 8 +++----- src/primaite/agents/simple.py | 14 ++++++-------- src/primaite/config/lay_down_config.py | 8 +++----- src/primaite/config/training_config.py | 6 ++---- src/primaite/environment/observations.py | 5 ++--- src/primaite/environment/primaite_env.py | 8 +++----- src/primaite/environment/reward.py | 5 ++--- src/primaite/notebooks/__init__.py | 7 ++----- src/primaite/setup/old_installation_clean_up.py | 2 +- src/primaite/setup/reset_demo_notebooks.py | 7 ++----- src/primaite/setup/reset_example_configs.py | 2 +- src/primaite/setup/setup_app_dirs.py | 7 ++----- src/primaite/utils/package_data.py | 7 ++----- 15 files changed, 37 insertions(+), 65 deletions(-) diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index af860996..3c18e1f3 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -4,8 +4,9 @@ from __future__ import annotations import json from abc import ABC, abstractmethod from datetime import datetime +from logging import Logger from pathlib import Path -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Union from uuid import uuid4 import primaite @@ -16,10 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite from primaite.utils.session_metadata_parser import parse_session_metadata -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def get_session_path(session_timestamp: datetime) -> Path: diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 8afc98a1..bde3a621 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -4,8 +4,9 @@ from __future__ import annotations import json import shutil from datetime import datetime +from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Optional, Union from uuid import uuid4 from ray.rllib.algorithms import Algorithm @@ -19,10 +20,7 @@ from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) # TODO: verify type of env_config diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 881426ab..5a9f9482 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -2,8 +2,9 @@ from __future__ import annotations import json +from logging import Logger from pathlib import Path -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, Union import numpy as np from stable_baselines3 import A2C, PPO @@ -14,10 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) class SB3Agent(AgentSessionABC): diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index bfc7bcf2..18ffa72b 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,12 +1,10 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. -from typing import TYPE_CHECKING + +import numpy as np from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum -if TYPE_CHECKING: - import numpy as np - class RandomAgent(HardCodedAgentSessionABC): """ @@ -15,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC): Get a completely random action from the action space. """ - def _calculate_action(self, obs: "np.ndarray") -> int: + def _calculate_action(self, obs: np.ndarray) -> int: return self._env.action_space.sample() @@ -26,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC): All action spaces setup so dummy action is always 0 regardless of action type used. """ - def _calculate_action(self, obs: "np.ndarray") -> int: + def _calculate_action(self, obs: np.ndarray) -> int: return 0 @@ -37,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC): A valid ACL action that has no effect; does nothing. """ - def _calculate_action(self, obs: "np.ndarray") -> int: + def _calculate_action(self, obs: np.ndarray) -> int: nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] nothing_action = transform_action_acl_enum(nothing_action) nothing_action = get_new_action(nothing_action, self._env.action_dict) @@ -52,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC): A valid Node action that has no effect; does nothing. """ - def _calculate_action(self, obs: "np.ndarray") -> int: + def _calculate_action(self, obs: np.ndarray) -> int: nothing_action = [1, "NONE", "ON", 0] nothing_action = transform_action_node_enum(nothing_action) nothing_action = get_new_action(nothing_action, self._env.action_dict) diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 80b0f619..9cadc509 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,15 +1,13 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. +from logging import Logger from pathlib import Path -from typing import Any, Dict, Final, TYPE_CHECKING, Union +from typing import Any, Dict, Final, Union import yaml from primaite import getLogger, USERS_CONFIG_DIR -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index f618b37c..f2229efb 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -2,8 +2,9 @@ from __future__ import annotations from dataclasses import dataclass, field +from logging import Logger from pathlib import Path -from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Final, Optional, Union import yaml @@ -18,9 +19,6 @@ from primaite.common.enums import ( SessionType, ) -if TYPE_CHECKING: - from logging import Logger - _LOGGER: Logger = getLogger(__name__) _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index ebc47043..0e613fe4 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -2,6 +2,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod +from logging import Logger from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union import numpy as np @@ -15,12 +16,10 @@ from primaite.nodes.service_node import ServiceNode # TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking # Therefore, this avoids circular dependency problem. if TYPE_CHECKING: - from logging import Logger - from primaite.environment.primaite_env import Primaite -_LOGGER: "Logger" = logging.getLogger(__name__) +_LOGGER: Logger = logging.getLogger(__name__) class AbstractObservationComponent(ABC): diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 8f34204b..4b830994 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,9 +3,10 @@ import copy import logging import uuid as uuid +from logging import Logger from pathlib import Path from random import choice, randint, sample, uniform -from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Final, List, Tuple, Union import networkx as nx import numpy as np @@ -49,10 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) class Primaite(Env): diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index aad15246..92ef89ec 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -1,5 +1,6 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. """Implements reward function.""" +from logging import Logger from typing import Dict, TYPE_CHECKING, Union from primaite import getLogger @@ -10,12 +11,10 @@ from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode if TYPE_CHECKING: - from logging import Logger - from primaite.config.training_config import TrainingConfig from primaite.pol.ier import IER -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def calculate_reward_function( diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index eaf10005..390fddb4 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -5,14 +5,11 @@ import importlib.util import os import subprocess import sys -from typing import TYPE_CHECKING +from logging import Logger from primaite import getLogger, NOTEBOOKS_DIR -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def start_jupyter_session() -> None: diff --git a/src/primaite/setup/old_installation_clean_up.py b/src/primaite/setup/old_installation_clean_up.py index 43950e4f..858ecfd9 100644 --- a/src/primaite/setup/old_installation_clean_up.py +++ b/src/primaite/setup/old_installation_clean_up.py @@ -6,7 +6,7 @@ from primaite import getLogger if TYPE_CHECKING: from logging import Logger -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def run() -> None: diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 775f43b5..f47af1dc 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -2,17 +2,14 @@ import filecmp import os import shutil +from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING import pkg_resources from primaite import getLogger, NOTEBOOKS_DIR -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def run(overwrite_existing: bool = True) -> None: diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index df3b36a1..d50b24b5 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -12,7 +12,7 @@ from primaite import getLogger, USERS_CONFIG_DIR if TYPE_CHECKING: from logging import Logger -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def run(overwrite_existing: bool = True) -> None: diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 56f16a08..68b5d772 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -1,12 +1,9 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. -from typing import TYPE_CHECKING +from logging import Logger from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def run() -> None: diff --git a/src/primaite/utils/package_data.py b/src/primaite/utils/package_data.py index b9abca8f..96157b40 100644 --- a/src/primaite/utils/package_data.py +++ b/src/primaite/utils/package_data.py @@ -1,16 +1,13 @@ # Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence. import os +from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING import pkg_resources from primaite import getLogger -if TYPE_CHECKING: - from logging import Logger - -_LOGGER: "Logger" = getLogger(__name__) +_LOGGER: Logger = getLogger(__name__) def get_file_path(path: str) -> Path: