#917 - Synced with dev and added better logging

This commit is contained in:
Chris McCarthy
2023-06-28 12:01:01 +01:00
parent 498e6a7ac1
commit 7482192046
12 changed files with 170 additions and 87 deletions

View File

@@ -2,9 +2,12 @@
import logging
import logging.config
import sys
from logging import Logger, StreamHandler
from bisect import bisect
from logging import Formatter, LogRecord, StreamHandler
from logging import Logger
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict
from typing import Final
import pkg_resources
@@ -68,6 +71,33 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``.
# region Setup Logging
class _LevelFormatter(Formatter):
"""
A custom level-specific formatter.
Credit to: https://stackoverflow.com/a/68154386
"""
def __init__(self, formats: Dict[int, str], **kwargs):
super().__init__()
if "fmt" in kwargs:
raise ValueError(
"Format string must be passed to level-surrogate formatters, "
"not this one"
)
self.formats = sorted(
(level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()
)
def format(self, record: LogRecord) -> str:
"""Overrides ``Formatter.format``."""
idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1)
level, formatter = self.formats[idx]
return formatter.format(record)
def _log_dir() -> Path:
if sys.platform == "win32":
dir_path = _PLATFORM_DIRS.user_data_path / "logs"
@@ -76,6 +106,16 @@ def _log_dir() -> Path:
return dir_path
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"]
}
)
LOG_DIR: Final[Path] = _log_dir()
"""The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS."""
@@ -85,6 +125,7 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log"
"""The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS."""
_STREAM_HANDLER: Final[StreamHandler] = StreamHandler()
_FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
filename=LOG_PATH,
maxBytes=10485760, # 10MB
@@ -95,8 +136,8 @@ _STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"]
_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR))
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
_LOGGER = logging.getLogger(__name__)

View File

@@ -3,11 +3,11 @@ import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Optional, Final, Dict, Union, List
from typing import Optional, Final, Dict, Union
from uuid import uuid4
import primaite
from primaite import getLogger, SESSIONS_DIR
from primaite.common.enums import OutputVerboseLevel
from primaite.config import lay_down_config
from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
@@ -141,14 +141,13 @@ class AgentSessionABC(ABC):
@abstractmethod
def _setup(self):
if self.output_verbose_level >= OutputVerboseLevel.INFO:
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment "
"(PrimAITE)"
)
_LOGGER.debug(
f"The output directory for this agent is: {self.session_path}"
)
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment "
f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(
f"The output directory for this session is: {self.session_path}"
)
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@@ -165,6 +164,7 @@ class AgentSessionABC(ABC):
**kwargs
):
if self._can_learn:
_LOGGER.info("Finished learning")
_LOGGER.debug("Writing transactions")
self._update_session_metadata_file()
self._can_evaluate = True
@@ -176,7 +176,7 @@ class AgentSessionABC(ABC):
episodes: Optional[int] = None,
**kwargs
):
pass
_LOGGER.info("Finished evaluation")
@abstractmethod
def _get_latest_checkpoint(self):
@@ -260,6 +260,7 @@ class HardCodedAgentSessionABC(AgentSessionABC):
# Introduce a delay between steps
time.sleep(self._training_config.time_delay / 1000)
self._env.close()
super().evaluate()
@classmethod
def load(cls):

View File

@@ -152,7 +152,8 @@ class RLlibAgent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()

View File

@@ -69,8 +69,9 @@ class SB3Agent(AgentSessionABC):
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
):
self._agent.save(
self.checkpoints_path / f"sb3ppo_{episode_count}.zip")
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -86,6 +87,8 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
@@ -106,6 +109,9 @@ class SB3Agent(AgentSessionABC):
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(f"Beginning evaluation for {episodes} episodes @"
f" {time_steps} time steps...")
for episode in range(episodes):
obs = self._env.reset()

View File

@@ -6,8 +6,8 @@ from primaite.common.enums import (
NodeHardwareAction,
NodeSoftwareAction,
SoftwareState,
NodePOLType
)
from primaite.common.enums import NodePOLType
def transform_action_node_readable(action):

View File

@@ -160,13 +160,26 @@ def setup(overwrite_existing: bool = True):
@app.command()
def session(tc: str, ldc: str):
def session(tc: Optional[str] = None, ldc: Optional[str] = None):
"""
Run a PrimAITE session.
:param tc: The training config filepath.
:param ldc: The lay down config file path.
tc: The training config filepath. Optional. If no value is passed then
example default training config is used from:
~/primaite/config/example_config/training/training_config_main.yaml.
ldc: The lay down config file path. Optional. If no value is passed then
example default lay down config is used from:
~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml.
"""
from primaite.main import run
from primaite.config.training_config import main_training_config_path
from primaite.config.lay_down_config import data_manipulation_config_path
if not tc:
tc = main_training_config_path()
if not ldc:
ldc = data_manipulation_config_path()
run(training_config_path=tc, lay_down_config_path=ldc)

View File

@@ -83,9 +83,12 @@ class Protocol(Enum):
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAINING = 1
EVALUATION = 2
BOTH = 3
TRAIN = 1
"Train an agent"
EVAL = 2
"Evaluate an agent"
TRAIN_EVAL = 3
"Train then evaluate an agent"
class VerboseLevel(IntEnum):
@@ -141,7 +144,6 @@ class HardCodedAgentView(Enum):
class ActionType(Enum):
"""Action type enumeration."""
NODE = 0
ACL = 1
ANY = 2
@@ -149,7 +151,6 @@ class ActionType(Enum):
class ObservationType(Enum):
"""Observation type enumeration."""
BOX = 0
MULTIDISCRETE = 1

View File

@@ -38,26 +38,23 @@ hard_coded_agent_view: FULL
action_type: ANY
# Number of episodes to run per session
num_episodes: 100
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 100
checkpoint_every_n_episodes: 10
# Time delay between steps (for generic agents)
time_delay: 3
time_delay: 5
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN
# Environment config values
# The high value for the observation space

View File

@@ -69,7 +69,7 @@ class TrainingConfig:
"The delay between steps (ms). Applies to generic agents only"
# file
session_type: SessionType = SessionType.TRAINING
session_type: SessionType = SessionType.TRAIN
"The type of PrimAITE session to run"
load_agent: str = False
@@ -171,6 +171,7 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
@classmethod
def from_dict(
cls,
@@ -183,7 +184,7 @@ class TrainingConfig:
"action_type": ActionType,
"session_type": SessionType,
"output_verbose_level": OutputVerboseLevel,
"hard_coded_agent_view": HardCodedAgentView
"hard_coded_agent_view": HardCodedAgentView,
}
for field, enum_class in field_enum_map.items():
@@ -211,6 +212,20 @@ class TrainingConfig:
return data
def __str__(self) -> str:
    """
    Return a human-readable one-line summary of this training config.

    Only the fields relevant to the selected framework/agent are included:
    ``deep_learning_framework`` appears only for RLlib, and
    ``hard_coded_agent_view`` only for the hard-coded agent.
    """
    tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, "
    # Deep-learning framework is only meaningful for the RLlib framework.
    if self.agent_framework is AgentFramework.RLLIB:
        tc += f"deep_learning_framework=" \
              f"{self.deep_learning_framework.name}, "
    tc += f"agent_identifier={self.agent_identifier.name}, "
    # The view setting only applies to the hard-coded (rule-based) agent.
    if self.agent_identifier is AgentIdentifier.HARDCODED:
        tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, "
    tc += f"action_type={self.action_type.name}, "
    tc += f"observation_space={self.observation_space}, "
    tc += f"num_episodes={self.num_episodes}, "
    tc += f"num_steps={self.num_steps})"
    return tc
def load(
file_path: Union[str, Path],

View File

@@ -15,7 +15,8 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -36,13 +37,15 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
@@ -61,12 +64,12 @@ class Primaite(Env):
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
"""
The Primaite constructor.
@@ -86,6 +89,7 @@ class Primaite(Env):
self.training_config: TrainingConfig = training_config.load(
training_config_path
)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
self.episode_steps = self.training_config.num_steps
@@ -207,16 +211,14 @@ class Primaite(Env):
plt.savefig(file_path, format="PNG")
plt.clf()
except Exception:
_LOGGER.error("Could not save network diagram")
_LOGGER.error("Exception occured", exc_info=True)
print("Could not save network diagram")
_LOGGER.error("Could not save network diagram", exc_info=True)
# Initiate observation space
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
if self.training_config.action_type == ActionType.NODE:
_LOGGER.info("Action space type NODE selected")
_LOGGER.debug("Action space type NODE selected")
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
@@ -225,7 +227,7 @@ class Primaite(Env):
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
_LOGGER.debug("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
@@ -236,11 +238,11 @@ class Primaite(Env):
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
_LOGGER.debug("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info(
_LOGGER.error(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
@@ -301,17 +303,14 @@ class Primaite(Env):
done: Indicates episode is complete if True
step_info: Additional information relating to this step
"""
# Introduce a delay between steps
time.sleep(self.training_config.time_delay / 1000)
if self.step_count == 0:
print(f"Episode: {str(self.episode_count)}")
_LOGGER.info(f"Episode: {str(self.episode_count)}")
# TEMP
done = False
self.step_count += 1
self.total_step_count += 1
# print("Episode step: " + str(self.step_count))
# Need to clear traffic on all links first
for link_key, link_value in self.links.items():
@@ -322,7 +321,8 @@ class Primaite(Env):
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
datetime.now(), self.agent_identifier, self.episode_count,
self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
@@ -352,7 +352,8 @@ class Primaite(Env):
self.nodes_post_pol = copy.deepcopy(self.nodes)
self.links_post_pol = copy.deepcopy(self.links)
# Reference
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
apply_node_pol(self.nodes_reference, self.node_pol,
self.step_count) # Node PoL
apply_iers(
self.network_reference,
self.nodes_reference,
@@ -389,7 +390,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
_LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -397,7 +398,7 @@ class Primaite(Env):
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
print(f" Average Reward: {str(self.average_reward)}")
_LOGGER.info(f" Average Reward: {str(self.average_reward)}")
# Load the reward into the transaction
transaction.set_reward(reward)
@@ -428,6 +429,7 @@ class Primaite(Env):
self.timestamp_str
)
self.csv_file.close()
def init_acl(self):
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
@@ -435,9 +437,9 @@ class Primaite(Env):
def output_link_status(self):
"""Output the link status of all links to the console."""
for link_key, link_value in self.links.items():
print("Link ID: " + link_value.get_id())
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(
_LOGGER.debug(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
@@ -457,11 +459,11 @@ class Primaite(Env):
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
@@ -529,7 +531,8 @@ class Primaite(Env):
elif property_action == 1:
# Patch (valid action if it's good or compromised)
node.set_service_state(
self.services_list[service_index], SoftwareState.PATCHING
self.services_list[service_index],
SoftwareState.PATCHING
)
else:
# Node is not of Service Type
@@ -589,7 +592,8 @@ class Primaite(Env):
acl_rule_source = "ANY"
else:
node = list(self.nodes.values())[action_source_ip - 1]
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
if isinstance(node, ServiceNode) or isinstance(node,
ActiveNode):
acl_rule_source = node.ip_address
else:
return
@@ -598,7 +602,8 @@ class Primaite(Env):
acl_rule_destination = "ANY"
else:
node = list(self.nodes.values())[action_destination_ip - 1]
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
if isinstance(node, ServiceNode) or isinstance(node,
ActiveNode):
acl_rule_destination = node.ip_address
else:
return
@@ -683,7 +688,8 @@ class Primaite(Env):
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
self.obs_handler = ObservationsHandler.from_config(self,
self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation
@@ -727,8 +733,7 @@ class Primaite(Env):
_LOGGER.error(f"Invalid item_type: {item_type}")
pass
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
_LOGGER.debug("Environment configuration loaded")
def create_node(self, item):
"""
@@ -791,7 +796,8 @@ class Primaite(Env):
service_protocol = service["name"]
service_port = service["port"]
service_state = SoftwareState[service["state"]]
node.add_service(Service(service_protocol, service_port, service_state))
node.add_service(
Service(service_protocol, service_port, service_state))
else:
# Bad formatting
pass
@@ -844,7 +850,8 @@ class Primaite(Env):
dest_node_ref: Node = self.nodes_reference[link_destination]
# Add link to network (reference)
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
self.network_reference.add_edge(source_node_ref, dest_node_ref,
id=link_name)
# Add link to link dictionary (reference)
self.links_reference[link_name] = Link(
@@ -1119,7 +1126,8 @@ class Primaite(Env):
# All nodes have these parameters
node_id = item["node_id"]
node_class = item["node_class"]
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
node_hardware_state: HardwareState = HardwareState[
item["hardware_state"]]
node: NodeUnion = self.nodes[node_id]
node_ref = self.nodes_reference[node_id]
@@ -1185,7 +1193,8 @@ class Primaite(Env):
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action, service_state]
action = [node, node_property, node_action,
service_state]
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
actions[action_key] = action

View File

@@ -47,8 +47,7 @@ class PrimaiteSession:
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
auto: bool = True
lay_down_config_path: Union[str, Path]
):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
@@ -64,13 +63,8 @@ class PrimaiteSession:
self._lay_down_config_path
)
self._auto: bool = auto
self._agent_session: AgentSessionABC = None # noqa
if self._auto:
self.setup()
self.learn()
def setup(self):
if self._training_config.agent_framework == AgentFramework.CUSTOM:
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
@@ -157,7 +151,7 @@ class PrimaiteSession:
episodes: Optional[int] = None,
**kwargs
):
if not self._training_config.session_type == SessionType.EVALUATION:
if not self._training_config.session_type == SessionType.EVAL:
self._agent_session.learn(time_steps, episodes, **kwargs)
def evaluate(
@@ -166,5 +160,5 @@ class PrimaiteSession:
episodes: Optional[int] = None,
**kwargs
):
if not self._training_config.session_type == SessionType.TRAINING:
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(time_steps, episodes, **kwargs)

View File

@@ -2,4 +2,9 @@
# Logging
log_level: INFO
logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
logger_format:
DEBUG: '%(asctime)s: %(message)s'
INFO: '%(asctime)s: %(message)s'
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'