#1386: added documentation + dealing with pre-commit checks

This commit is contained in:
Czar Echavez
2023-06-20 11:19:05 +01:00
parent 9fb30ffe1b
commit 99399cbda6
17 changed files with 311 additions and 192 deletions

View File

@@ -31,7 +31,7 @@ def _get_primaite_config():
"INFO": logging.INFO,
"WARN": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
return primaite_config

View File

@@ -3,8 +3,8 @@
import logging
import os
import shutil
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from primaite import LOG_PATH
if os.path.isfile(LOG_PATH):
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
from typing import Any, Dict, Final, Optional, Union
import yaml
@@ -173,8 +173,7 @@ def main_training_config_path() -> Path:
return path
def load(file_path: Union[str, Path],
legacy_file: bool = False) -> TrainingConfig:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -219,9 +218,7 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -233,10 +230,7 @@ def convert_legacy_training_config_dict(
don't have action_type values.
:return: The converted training config dict.
"""
config_dict = {
"num_steps": num_steps,
"action_type": action_type
}
config_dict = {"num_steps": num_steps, "action_type": action_type}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:

View File

@@ -14,8 +14,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -24,8 +23,9 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState, ObservationType,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -35,17 +35,14 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
@@ -178,7 +175,6 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -200,7 +196,6 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -222,7 +217,9 @@ class Primaite(Env):
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
elif self.training_config.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
@@ -233,13 +230,19 @@ class Primaite(Env):
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
elif self.training_config.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict), seed=self.training_config.seed)
self.action_space = spaces.Discrete(
len(self.action_dict), seed=self.training_config.seed
)
else:
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
_LOGGER.info(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -380,7 +383,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
#print(f" Step {self.step_count} Reward: {str(reward)}")
# print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -407,12 +410,11 @@ class Primaite(Env):
return self.env_obs, reward, done, self.step_info
def close(self):
"""Calls the __close__ method."""
self.__close__()
def __close__(self):
"""
Override close function
"""
"""Override close function."""
self.csv_file.close()
def init_acl(self):
@@ -1039,7 +1041,6 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""
Extracts action_info.

View File

@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = getLogger(__name__)
@@ -87,7 +86,13 @@ def run_stable_baselines3_ppo(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
agent = PPO(
PPOMlp,
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -106,8 +111,7 @@ def run_stable_baselines3_ppo(
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
@@ -146,7 +150,13 @@ def run_stable_baselines3_a2c(
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps, seed=env.training_config.seed)
agent = A2C(
"MlpPolicy",
env,
verbose=0,
n_steps=config_values.num_steps,
seed=env.training_config.seed,
)
if config_values.session_type == "TRAINING":
# We're in a training session
@@ -164,8 +174,7 @@ def run_stable_baselines3_a2c(
for step in range(0, config_values.num_steps):
action, _states = agent.predict(
obs,
deterministic=env.training_config.deterministic
obs, deterministic=env.training_config.deterministic
)
# convert to int if action is a numpy array
if isinstance(action, np.ndarray):
@@ -368,5 +377,3 @@ if __name__ == "__main__":
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -46,6 +46,7 @@ class Node:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
@@ -64,14 +65,14 @@ class Node:
self.hardware_state = HardwareState.ON
def update_booting_status(self):
"""Updates the booting count"""
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
"""Updates the shutdown count"""
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0

View File

@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
service_value.reduce_patching_count()
def update_resetting_status(self):
"""Updates the resetting counter for any service that are resetting."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
"""Updates the booting counter for any service that are booting up."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state =SoftwareState.GOOD
service.software_state = SoftwareState.GOOD

View File

@@ -17,7 +17,6 @@ def start_jupyter_session():
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if importlib.util.find_spec("jupyter") is not None:
jupyter_cmd = "python3 -m jupyter lab"
if sys.platform == "win32":