#917 - Synced with dev and added better logging
This commit is contained in:
@@ -2,9 +2,12 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
from logging import Logger, StreamHandler
|
||||
from bisect import bisect
|
||||
from logging import Formatter, LogRecord, StreamHandler
|
||||
from logging import Logger
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Final
|
||||
|
||||
import pkg_resources
|
||||
@@ -68,6 +71,33 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
|
||||
|
||||
|
||||
# region Setup Logging
|
||||
class _LevelFormatter(Formatter):
|
||||
"""
|
||||
A custom level-specific formatter.
|
||||
|
||||
Credit to: https://stackoverflow.com/a/68154386
|
||||
"""
|
||||
|
||||
def __init__(self, formats: Dict[int, str], **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
raise ValueError(
|
||||
"Format string must be passed to level-surrogate formatters, "
|
||||
"not this one"
|
||||
)
|
||||
|
||||
self.formats = sorted(
|
||||
(level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()
|
||||
)
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
"""Overrides ``Formatter.format``."""
|
||||
idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
|
||||
level, formatter = self.formats[idx]
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
def _log_dir() -> Path:
|
||||
if sys.platform == "win32":
|
||||
dir_path = _PLATFORM_DIRS.user_data_path / "logs"
|
||||
@@ -76,6 +106,16 @@ def _log_dir() -> Path:
|
||||
return dir_path
|
||||
|
||||
|
||||
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"]
|
||||
}
|
||||
)
|
||||
|
||||
LOG_DIR: Final[Path] = _log_dir()
|
||||
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""
|
||||
|
||||
@@ -85,6 +125,7 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
|
||||
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""
|
||||
|
||||
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
|
||||
|
||||
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
|
||||
filename=LOG_PATH,
|
||||
maxBytes=10485760, # 10MB
|
||||
@@ -95,8 +136,8 @@ _STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
|
||||
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
|
||||
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
|
||||
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
|
||||
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Union, List
|
||||
from typing import Optional, Final, Dict, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.common.enums import OutputVerboseLevel
|
||||
from primaite.config import lay_down_config
|
||||
from primaite.config import training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
@@ -141,14 +141,13 @@ class AgentSessionABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
if self.output_verbose_level >= OutputVerboseLevel.INFO:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment "
|
||||
"(PrimAITE)"
|
||||
)
|
||||
_LOGGER.debug(
|
||||
f"The output directory for this agent is: {self.session_path}"
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment "
|
||||
f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(
|
||||
f"The output directory for this session is: {self.session_path}"
|
||||
)
|
||||
self._write_session_metadata_file()
|
||||
self._can_learn = True
|
||||
self._can_evaluate = False
|
||||
@@ -165,6 +164,7 @@ class AgentSessionABC(ABC):
|
||||
**kwargs
|
||||
):
|
||||
if self._can_learn:
|
||||
_LOGGER.info("Finished learning")
|
||||
_LOGGER.debug("Writing transactions")
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
@@ -176,7 +176,7 @@ class AgentSessionABC(ABC):
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
pass
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
@@ -260,6 +260,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
|
||||
@@ -152,7 +152,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
|
||||
@@ -69,8 +69,9 @@ class SB3Agent(AgentSessionABC):
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
):
|
||||
self._agent.save(
|
||||
self.checkpoints_path / f"sb3ppo_{episode_count}.zip")
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
@@ -86,6 +87,8 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
@@ -106,6 +109,9 @@ class SB3Agent(AgentSessionABC):
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
_LOGGER.info(f"Beginning evaluation for {episodes} episodes @"
|
||||
f" {time_steps} time steps...")
|
||||
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from primaite.common.enums import (
|
||||
NodeHardwareAction,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
NodePOLType
|
||||
)
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
|
||||
def transform_action_node_readable(action):
|
||||
|
||||
@@ -160,13 +160,26 @@ def setup(overwrite_existing: bool = True):
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: str, ldc: str):
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
:param tc: The training config filepath.
|
||||
:param ldc: The lay down config file path.
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
~/primaite/config/example_config/training/training_config_main.yaml.
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml.
|
||||
"""
|
||||
from primaite.main import run
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
if not ldc:
|
||||
ldc = data_manipulation_config_path()
|
||||
|
||||
run(training_config_path=tc, lay_down_config_path=ldc)
|
||||
|
||||
@@ -83,9 +83,12 @@ class Protocol(Enum):
|
||||
|
||||
class SessionType(Enum):
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
TRAINING = 1
|
||||
EVALUATION = 2
|
||||
BOTH = 3
|
||||
TRAIN = 1
|
||||
"Train an agent"
|
||||
EVAL = 2
|
||||
"Evaluate an agent"
|
||||
TRAIN_EVAL = 3
|
||||
"Train then evaluate an agent"
|
||||
|
||||
|
||||
class VerboseLevel(IntEnum):
|
||||
@@ -141,7 +144,6 @@ class HardCodedAgentView(Enum):
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
@@ -149,7 +151,6 @@ class ActionType(Enum):
|
||||
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
|
||||
@@ -38,26 +38,23 @@ hard_coded_agent_view: FULL
|
||||
action_type: ANY
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 100
|
||||
num_episodes: 10
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 100
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 3
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
|
||||
@@ -69,7 +69,7 @@ class TrainingConfig:
|
||||
"The delay between steps (ms). Applies to generic agents only"
|
||||
|
||||
# file
|
||||
session_type: SessionType = SessionType.TRAINING
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: str = False
|
||||
@@ -171,6 +171,7 @@ class TrainingConfig:
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
@@ -183,7 +184,7 @@ class TrainingConfig:
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"output_verbose_level": OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
}
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
@@ -211,6 +212,20 @@ class TrainingConfig:
|
||||
|
||||
return data
|
||||
|
||||
def __str__(self) -> str:
|
||||
tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, "
|
||||
if self.agent_framework is AgentFramework.RLLIB:
|
||||
tc += f"deep_learning_framework=" \
|
||||
f"{self.deep_learning_framework.name}, "
|
||||
tc += f"agent_identifier={self.agent_identifier.name}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, "
|
||||
tc += f"action_type={self.action_type.name}, "
|
||||
tc += f"observation_space={self.observation_space}, "
|
||||
tc += f"num_episodes={self.num_episodes}, "
|
||||
tc += f"num_steps={self.num_steps})"
|
||||
return tc
|
||||
|
||||
|
||||
def load(
|
||||
file_path: Union[str, Path],
|
||||
|
||||
@@ -15,7 +15,8 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -36,13 +37,15 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
@@ -61,12 +64,12 @@ class Primaite(Env):
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
The Primaite constructor.
|
||||
@@ -86,6 +89,7 @@ class Primaite(Env):
|
||||
self.training_config: TrainingConfig = training_config.load(
|
||||
training_config_path
|
||||
)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
self.episode_steps = self.training_config.num_steps
|
||||
@@ -207,16 +211,14 @@ class Primaite(Env):
|
||||
plt.savefig(file_path, format="PNG")
|
||||
plt.clf()
|
||||
except Exception:
|
||||
_LOGGER.error("Could not save network diagram")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
print("Could not save network diagram")
|
||||
_LOGGER.error("Could not save network diagram", exc_info=True)
|
||||
|
||||
# Initiate observation space
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.info("Action space type NODE selected")
|
||||
_LOGGER.debug("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
@@ -225,7 +227,7 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
_LOGGER.debug("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
@@ -236,11 +238,11 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.training_config.action_type == ActionType.ANY:
|
||||
_LOGGER.info("Action space type ANY selected - Node + ACL")
|
||||
_LOGGER.debug("Action space type ANY selected - Node + ACL")
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info(
|
||||
_LOGGER.error(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
@@ -301,17 +303,14 @@ class Primaite(Env):
|
||||
done: Indicates episode is complete if True
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self.training_config.time_delay / 1000)
|
||||
if self.step_count == 0:
|
||||
print(f"Episode: {str(self.episode_count)}")
|
||||
_LOGGER.info(f"Episode: {str(self.episode_count)}")
|
||||
|
||||
# TEMP
|
||||
done = False
|
||||
|
||||
self.step_count += 1
|
||||
self.total_step_count += 1
|
||||
# print("Episode step: " + str(self.step_count))
|
||||
|
||||
# Need to clear traffic on all links first
|
||||
for link_key, link_value in self.links.items():
|
||||
@@ -322,7 +321,8 @@ class Primaite(Env):
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
datetime.now(), self.agent_identifier, self.episode_count,
|
||||
self.step_count
|
||||
)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
|
||||
@@ -352,7 +352,8 @@ class Primaite(Env):
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_node_pol(self.nodes_reference, self.node_pol,
|
||||
self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
@@ -389,7 +390,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
_LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -397,7 +398,7 @@ class Primaite(Env):
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
print(f" Average Reward: {str(self.average_reward)}")
|
||||
_LOGGER.info(f" Average Reward: {str(self.average_reward)}")
|
||||
# Load the reward into the transaction
|
||||
transaction.set_reward(reward)
|
||||
|
||||
@@ -428,6 +429,7 @@ class Primaite(Env):
|
||||
self.timestamp_str
|
||||
)
|
||||
self.csv_file.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
@@ -435,9 +437,9 @@ class Primaite(Env):
|
||||
def output_link_status(self):
|
||||
"""Output the link status of all links to the console."""
|
||||
for link_key, link_value in self.links.items():
|
||||
print("Link ID: " + link_value.get_id())
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
print(
|
||||
_LOGGER.debug(
|
||||
" Protocol: "
|
||||
+ protocol.get_name().name
|
||||
+ ", Load: "
|
||||
@@ -457,11 +459,11 @@ class Primaite(Env):
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
@@ -529,7 +531,8 @@ class Primaite(Env):
|
||||
elif property_action == 1:
|
||||
# Patch (valid action if it's good or compromised)
|
||||
node.set_service_state(
|
||||
self.services_list[service_index], SoftwareState.PATCHING
|
||||
self.services_list[service_index],
|
||||
SoftwareState.PATCHING
|
||||
)
|
||||
else:
|
||||
# Node is not of Service Type
|
||||
@@ -589,7 +592,8 @@ class Primaite(Env):
|
||||
acl_rule_source = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_source_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
if isinstance(node, ServiceNode) or isinstance(node,
|
||||
ActiveNode):
|
||||
acl_rule_source = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -598,7 +602,8 @@ class Primaite(Env):
|
||||
acl_rule_destination = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_destination_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
if isinstance(node, ServiceNode) or isinstance(node,
|
||||
ActiveNode):
|
||||
acl_rule_destination = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -683,7 +688,8 @@ class Primaite(Env):
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
"""
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
self.obs_handler = ObservationsHandler.from_config(self,
|
||||
self.obs_config)
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
@@ -727,8 +733,7 @@ class Primaite(Env):
|
||||
_LOGGER.error(f"Invalid item_type: {item_type}")
|
||||
pass
|
||||
|
||||
_LOGGER.info("Environment configuration loaded")
|
||||
print("Environment configuration loaded")
|
||||
_LOGGER.debug("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
"""
|
||||
@@ -791,7 +796,8 @@ class Primaite(Env):
|
||||
service_protocol = service["name"]
|
||||
service_port = service["port"]
|
||||
service_state = SoftwareState[service["state"]]
|
||||
node.add_service(Service(service_protocol, service_port, service_state))
|
||||
node.add_service(
|
||||
Service(service_protocol, service_port, service_state))
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
@@ -844,7 +850,8 @@ class Primaite(Env):
|
||||
dest_node_ref: Node = self.nodes_reference[link_destination]
|
||||
|
||||
# Add link to network (reference)
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref,
|
||||
id=link_name)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(
|
||||
@@ -1119,7 +1126,8 @@ class Primaite(Env):
|
||||
# All nodes have these parameters
|
||||
node_id = item["node_id"]
|
||||
node_class = item["node_class"]
|
||||
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
|
||||
node_hardware_state: HardwareState = HardwareState[
|
||||
item["hardware_state"]]
|
||||
|
||||
node: NodeUnion = self.nodes[node_id]
|
||||
node_ref = self.nodes_reference[node_id]
|
||||
@@ -1185,7 +1193,8 @@ class Primaite(Env):
|
||||
# Use MAX to ensure we get them all
|
||||
for node_action in range(4):
|
||||
for service_state in range(self.num_services):
|
||||
action = [node, node_property, node_action, service_state]
|
||||
action = [node, node_property, node_action,
|
||||
service_state]
|
||||
# check to see if it's a nothing action (has no effect)
|
||||
if is_valid_node_action(action):
|
||||
actions[action_key] = action
|
||||
|
||||
@@ -47,8 +47,7 @@ class PrimaiteSession:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
auto: bool = True
|
||||
lay_down_config_path: Union[str, Path]
|
||||
):
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
@@ -64,13 +63,8 @@ class PrimaiteSession:
|
||||
self._lay_down_config_path
|
||||
)
|
||||
|
||||
self._auto: bool = auto
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
|
||||
if self._auto:
|
||||
self.setup()
|
||||
self.learn()
|
||||
|
||||
def setup(self):
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
@@ -157,7 +151,7 @@ class PrimaiteSession:
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
if not self._training_config.session_type == SessionType.EVALUATION:
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
|
||||
def evaluate(
|
||||
@@ -166,5 +160,5 @@ class PrimaiteSession:
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs
|
||||
):
|
||||
if not self._training_config.session_type == SessionType.TRAINING:
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
|
||||
@@ -2,4 +2,9 @@
|
||||
|
||||
# Logging
|
||||
log_level: INFO
|
||||
logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
logger_format:
|
||||
DEBUG: '%(asctime)s: %(message)s'
|
||||
INFO: '%(asctime)s: %(message)s'
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
|
||||
Reference in New Issue
Block a user