#1386: added documentation + dealing with pre-commit checks

This commit is contained in:
Czar Echavez
2023-06-20 11:19:05 +01:00
parent 0ab4520904
commit db67a829d5
17 changed files with 311 additions and 192 deletions

View File

@@ -3,7 +3,7 @@ import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Union, Final
from typing import Final, Union
import pandas as pd
@@ -30,7 +30,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
def _get_primaite_env_from_config(
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
@@ -84,7 +84,7 @@ def run_generic(env, config_values):
def compare_file_content(output_a_file_path: str, output_b_file_path: str):
"""Function used to check if output of both given files are the same"""
"""Function used to check if output of both given files are the same."""
with open(output_a_file_path) as f1:
with open(output_b_file_path) as f2:
f1_content = f1.read()
@@ -95,13 +95,15 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str):
# both files have the same content
return True
# both files have different content
print(f"{output_a_file_path} and {output_b_file_path} has different contents")
print(
f"{output_a_file_path} and {output_b_file_path} has different contents"
)
return False
def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
"""Function used to check if contents of transaction files are the same"""
"""Function used to check if contents of transaction files are the same."""
# load output a file
data_a = pd.read_csv(output_a_file_path)
@@ -109,19 +111,17 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
data_b = pd.read_csv(output_b_file_path)
# remove the time stamp column
data_a.drop('Timestamp', inplace=True, axis=1)
data_b.drop('Timestamp', inplace=True, axis=1)
data_a.drop("Timestamp", inplace=True, axis=1)
data_b.drop("Timestamp", inplace=True, axis=1)
# if the comparison is empty, both files are the same i.e. True
return data_a.compare(data_b).empty
class TestSession:
def __init__(
self,
training_config_path,
laydown_config_path
):
"""Class that contains session values."""
def __init__(self, training_config_path, laydown_config_path):
self.session_timestamp: Final[datetime] = datetime.now()
self.session_dir = _get_session_path(self.session_timestamp)
self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
@@ -140,5 +140,8 @@ class TestSession:
print("Writing Session Metadata file...")
_write_session_metadata_file(
session_dir=self.session_dir, uuid="test", session_timestamp=self.session_timestamp, env=self.env
session_dir=self.session_dir,
uuid="test",
session_timestamp=self.session_timestamp,
env=self.env,
)

View File

@@ -2,8 +2,13 @@ import time
from primaite import getLogger
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.main import run_stable_baselines3_a2c, \
run_stable_baselines3_ppo, run_generic, _update_session_metadata_file, _get_session_path
from primaite.main import (
_get_session_path,
_update_session_metadata_file,
run_generic,
run_stable_baselines3_a2c,
run_stable_baselines3_ppo,
)
from primaite.transactions.transactions_to_file import write_transaction_to_file
from tests import TEST_CONFIG_ROOT
from tests.conftest import TestSession, compare_file_content, compare_transaction_file
@@ -12,18 +17,17 @@ _LOGGER = getLogger(__name__)
def test_generic_same_results():
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("Generic test run")
print("=======================")
print("")
# run session 1
session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
config_values = session1.env.training_config
@@ -31,17 +35,14 @@ def test_generic_same_results():
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = session1.env.episode_steps
run_generic(
env=session1.env,
config_values=session1.env.training_config
)
run_generic(env=session1.env, config_values=session1.env.training_config)
_update_session_metadata_file(session_dir=session1.session_dir, env=session1.env)
# run session 2
session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
config_values = session2.env.training_config
@@ -49,10 +50,7 @@ def test_generic_same_results():
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = session2.env.episode_steps
run_generic(
env=session2.env,
config_values=session2.env.training_config
)
run_generic(env=session2.env, config_values=session2.env.training_config)
_update_session_metadata_file(session_dir=session2.session_dir, env=session2.env)
@@ -61,36 +59,41 @@ def test_generic_same_results():
time.sleep(1)
# check if both outputs are the same
assert compare_file_content(
session1.env.csv_file.name,
session2.env.csv_file.name,
) is True
assert (
compare_file_content(
session1.env.csv_file.name,
session2.env.csv_file.name,
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
run_generic(
env=deterministic.env,
config_values=deterministic.env.training_config
run_generic(env=deterministic.env, config_values=deterministic.env.training_config)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
# check if both outputs are the same
assert compare_file_content(
deterministic.env.csv_file.name,
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_generic.csv",
) is True
assert (
compare_file_content(
deterministic.env.csv_file.name,
TEST_CONFIG_ROOT
/ "e2e/deterministic_test_outputs/deterministic_generic.csv",
)
is True
)
def test_ppo_same_results():
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("PPO test run")
@@ -99,7 +102,7 @@ def test_ppo_same_results():
training_session = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Train agent
@@ -123,12 +126,14 @@ def test_ppo_same_results():
timestamp_str=training_session.timestamp_str,
)
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
_update_session_metadata_file(
session_dir=training_session.session_dir, env=training_session.env
)
# Evaluate Agent again
eval_session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -137,7 +142,10 @@ def test_ppo_same_results():
# load the agent that was trained previously
eval_session1.env.training_config.load_agent = True
eval_session1.env.training_config.agent_load_file = _get_session_path(training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session1.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session1.env.training_config
@@ -154,11 +162,13 @@ def test_ppo_same_results():
timestamp_str=eval_session1.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
_update_session_metadata_file(
session_dir=eval_session1.session_dir, env=eval_session1.env
)
eval_session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -167,8 +177,10 @@ def test_ppo_same_results():
# load the agent that was trained previously
eval_session2.env.training_config.load_agent = True
eval_session2.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session2.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session2.env.training_config
@@ -185,18 +197,25 @@ def test_ppo_same_results():
timestamp_str=eval_session2.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
_update_session_metadata_file(
session_dir=eval_session2.session_dir, env=eval_session2.env
)
# check if both eval outputs are the same
assert compare_transaction_file(
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
) is True
assert (
compare_transaction_file(
eval_session1.session_dir
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
@@ -214,18 +233,23 @@ def test_ppo_same_results():
timestamp_str=deterministic.timestamp_str,
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
# check if both outputs are the same
assert compare_transaction_file(
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
) is True
assert (
compare_transaction_file(
deterministic.session_dir
/ f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
)
is True
)
def test_a2c_same_results():
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same"""
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same."""
print("")
print("=======================")
print("A2C test run")
@@ -234,7 +258,7 @@ def test_a2c_same_results():
training_session = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Train agent
@@ -258,12 +282,14 @@ def test_a2c_same_results():
timestamp_str=training_session.timestamp_str,
)
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
_update_session_metadata_file(
session_dir=training_session.session_dir, env=training_session.env
)
# Evaluate Agent again
eval_session1 = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -272,8 +298,10 @@ def test_a2c_same_results():
# load the agent that was trained previously
eval_session1.env.training_config.load_agent = True
eval_session1.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session1.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session1.env.training_config
@@ -290,11 +318,13 @@ def test_a2c_same_results():
timestamp_str=eval_session1.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
_update_session_metadata_file(
session_dir=eval_session1.session_dir, env=eval_session1.env
)
eval_session2 = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
# Get the number of steps (which is stored in the child config file)
@@ -303,8 +333,10 @@ def test_a2c_same_results():
# load the agent that was trained previously
eval_session2.env.training_config.load_agent = True
eval_session2.env.training_config.agent_load_file = _get_session_path(
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
eval_session2.env.training_config.agent_load_file = (
_get_session_path(training_session.session_timestamp)
/ f"agent_saved_{training_session.timestamp_str}.zip"
)
config_values = eval_session2.env.training_config
@@ -321,18 +353,25 @@ def test_a2c_same_results():
timestamp_str=eval_session2.timestamp_str,
)
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
_update_session_metadata_file(
session_dir=eval_session2.session_dir, env=eval_session2.env
)
# check if both eval outputs are the same
assert compare_transaction_file(
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
) is True
assert (
compare_transaction_file(
eval_session1.session_dir
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
eval_session2.session_dir
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
)
is True
)
# deterministic run
deterministic = TestSession(
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
data_manipulation_config_path()
data_manipulation_config_path(),
)
deterministic.env.training_config.deterministic = True
@@ -350,10 +389,16 @@ def test_a2c_same_results():
timestamp_str=deterministic.timestamp_str,
)
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
_update_session_metadata_file(
session_dir=deterministic.session_dir, env=deterministic.env
)
# check if both outputs are the same
assert compare_transaction_file(
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
) is True
assert (
compare_transaction_file(
deterministic.session_dir
/ f"all_transactions_{deterministic.timestamp_str}.csv",
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
)
is True
)

View File

@@ -27,7 +27,8 @@ def env(request):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)

View File

@@ -1,7 +1,13 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
@pytest.mark.parametrize(
"starting_operating_state, expected_operating_state",
[
(HardwareState.RESETTING, HardwareState.ON)
],
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
"""
Tests that a node resets correctly.
"""
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id = "0",
name = "node",
node_type = NodeType.COMPUTER,
priority = Priority.P1,
hardware_state = starting_operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.COMPROMISED,
file_system_state = FileSystemState.CORRUPT,
config_values=TrainingConfig()
node_id="0",
name="node",
node_type=NodeType.COMPUTER,
priority=Priority.P1,
hardware_state=starting_operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.COMPROMISED,
file_system_state=FileSystemState.CORRUPT,
config_values=TrainingConfig(),
)
for x in range(5):
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
assert active_node.file_system_state_actual == FileSystemState.GOOD
assert active_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.BOOTING, HardwareState.ON)
],
[(HardwareState.BOOTING, HardwareState.ON)],
)
def test_node_boots_correctly(operating_state, expected_operating_state):
"""
Tests that a node boots correctly.
"""
"""Tests that a node boots correctly."""
service_node = ServiceNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name = "node",
port = "80",
software_state = SoftwareState.COMPROMISED
)
service_node.add_service(
service_attributes
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_node.add_service(service_attributes)
for x in range(5):
service_node.update_booting_status()
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
assert service_attributes.software_state == SoftwareState.GOOD
assert service_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
],
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
)
def test_node_shutdown_correctly(operating_state, expected_operating_state):
"""
Tests that a node shutdown correctly.
"""
"""Tests that a node shutdown correctly."""
active_node = ActiveNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
for x in range(5):
active_node.update_shutdown_status()
assert active_node.hardware_state == expected_operating_state

View File

@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
run_generic_set_actions(env)
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
def test_agent_is_executing_actions_from_both_spaces():
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)