Remove redundant 'if TYPE_CHECKING' statements
This commit is contained in:
@@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
@@ -16,10 +17,7 @@ from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
@@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -19,10 +20,7 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: verify type of env_config
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -14,10 +15,7 @@ from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
@@ -15,7 +13,7 @@ class RandomAgent(HardCodedAgentSessionABC):
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
@@ -26,7 +24,7 @@ class DummyAgent(HardCodedAgentSessionABC):
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -37,7 +35,7 @@ class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
@@ -52,7 +50,7 @@ class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: "np.ndarray") -> int:
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Final, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -18,9 +19,6 @@ from primaite.common.enums import (
|
||||
SessionType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -15,12 +16,10 @@ from primaite.nodes.service_node import ServiceNode
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER: "Logger" = logging.getLogger(__name__)
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
import copy
|
||||
import logging
|
||||
import uuid as uuid
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Any, Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Final, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -49,10 +50,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements reward function."""
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -10,12 +11,10 @@ from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
|
||||
@@ -5,14 +5,11 @@ import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def start_jupyter_session() -> None:
|
||||
|
||||
@@ -6,7 +6,7 @@ from primaite import getLogger
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
|
||||
@@ -2,17 +2,14 @@
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger, NOTEBOOKS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
|
||||
@@ -12,7 +12,7 @@ from primaite import getLogger, USERS_CONFIG_DIR
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run(overwrite_existing: bool = True) -> None:
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_file_path(path: str) -> Path:
|
||||
|
||||
Reference in New Issue
Block a user