#1386: added documentation + dealing with pre-commit checks
This commit is contained in:
@@ -3,7 +3,7 @@ import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union, Final
|
||||
from typing import Final, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -30,7 +30,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
|
||||
|
||||
def _get_primaite_env_from_config(
|
||||
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||
training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
|
||||
):
|
||||
"""Takes a config path and returns the created instance of Primaite."""
|
||||
session_timestamp: datetime = datetime.now()
|
||||
@@ -84,7 +84,7 @@ def run_generic(env, config_values):
|
||||
|
||||
|
||||
def compare_file_content(output_a_file_path: str, output_b_file_path: str):
|
||||
"""Function used to check if output of both given files are the same"""
|
||||
"""Function used to check if output of both given files are the same."""
|
||||
with open(output_a_file_path) as f1:
|
||||
with open(output_b_file_path) as f2:
|
||||
f1_content = f1.read()
|
||||
@@ -95,13 +95,15 @@ def compare_file_content(output_a_file_path: str, output_b_file_path: str):
|
||||
# both files have the same content
|
||||
return True
|
||||
# both files have different content
|
||||
print(f"{output_a_file_path} and {output_b_file_path} has different contents")
|
||||
print(
|
||||
f"{output_a_file_path} and {output_b_file_path} has different contents"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
|
||||
"""Function used to check if contents of transaction files are the same"""
|
||||
"""Function used to check if contents of transaction files are the same."""
|
||||
# load output a file
|
||||
data_a = pd.read_csv(output_a_file_path)
|
||||
|
||||
@@ -109,19 +111,17 @@ def compare_transaction_file(output_a_file_path: str, output_b_file_path: str):
|
||||
data_b = pd.read_csv(output_b_file_path)
|
||||
|
||||
# remove the time stamp column
|
||||
data_a.drop('Timestamp', inplace=True, axis=1)
|
||||
data_b.drop('Timestamp', inplace=True, axis=1)
|
||||
data_a.drop("Timestamp", inplace=True, axis=1)
|
||||
data_b.drop("Timestamp", inplace=True, axis=1)
|
||||
|
||||
# if the comparison is empty, both files are the same i.e. True
|
||||
return data_a.compare(data_b).empty
|
||||
|
||||
|
||||
class TestSession:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
laydown_config_path
|
||||
):
|
||||
"""Class that contains session values."""
|
||||
|
||||
def __init__(self, training_config_path, laydown_config_path):
|
||||
self.session_timestamp: Final[datetime] = datetime.now()
|
||||
self.session_dir = _get_session_path(self.session_timestamp)
|
||||
self.timestamp_str = self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
@@ -140,5 +140,8 @@ class TestSession:
|
||||
print("Writing Session Metadata file...")
|
||||
|
||||
_write_session_metadata_file(
|
||||
session_dir=self.session_dir, uuid="test", session_timestamp=self.session_timestamp, env=self.env
|
||||
session_dir=self.session_dir,
|
||||
uuid="test",
|
||||
session_timestamp=self.session_timestamp,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,13 @@ import time
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.main import run_stable_baselines3_a2c, \
|
||||
run_stable_baselines3_ppo, run_generic, _update_session_metadata_file, _get_session_path
|
||||
from primaite.main import (
|
||||
_get_session_path,
|
||||
_update_session_metadata_file,
|
||||
run_generic,
|
||||
run_stable_baselines3_a2c,
|
||||
run_stable_baselines3_ppo,
|
||||
)
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import TestSession, compare_file_content, compare_transaction_file
|
||||
@@ -12,18 +17,17 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def test_generic_same_results():
|
||||
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same"""
|
||||
"""Runs seeded and deterministic Generic Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("Generic test run")
|
||||
print("=======================")
|
||||
print("")
|
||||
|
||||
|
||||
# run session 1
|
||||
session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
config_values = session1.env.training_config
|
||||
@@ -31,17 +35,14 @@ def test_generic_same_results():
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = session1.env.episode_steps
|
||||
|
||||
run_generic(
|
||||
env=session1.env,
|
||||
config_values=session1.env.training_config
|
||||
)
|
||||
run_generic(env=session1.env, config_values=session1.env.training_config)
|
||||
|
||||
_update_session_metadata_file(session_dir=session1.session_dir, env=session1.env)
|
||||
|
||||
# run session 2
|
||||
session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
config_values = session2.env.training_config
|
||||
@@ -49,10 +50,7 @@ def test_generic_same_results():
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
config_values.num_steps = session2.env.episode_steps
|
||||
|
||||
run_generic(
|
||||
env=session2.env,
|
||||
config_values=session2.env.training_config
|
||||
)
|
||||
run_generic(env=session2.env, config_values=session2.env.training_config)
|
||||
|
||||
_update_session_metadata_file(session_dir=session2.session_dir, env=session2.env)
|
||||
|
||||
@@ -61,36 +59,41 @@ def test_generic_same_results():
|
||||
time.sleep(1)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_file_content(
|
||||
session1.env.csv_file.name,
|
||||
session2.env.csv_file.name,
|
||||
) is True
|
||||
assert (
|
||||
compare_file_content(
|
||||
session1.env.csv_file.name,
|
||||
session2.env.csv_file.name,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/generic_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
|
||||
run_generic(
|
||||
env=deterministic.env,
|
||||
config_values=deterministic.env.training_config
|
||||
run_generic(env=deterministic.env, config_values=deterministic.env.training_config)
|
||||
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_file_content(
|
||||
deterministic.env.csv_file.name,
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_generic.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_file_content(
|
||||
deterministic.env.csv_file.name,
|
||||
TEST_CONFIG_ROOT
|
||||
/ "e2e/deterministic_test_outputs/deterministic_generic.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_ppo_same_results():
|
||||
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same"""
|
||||
|
||||
"""Runs seeded and deterministic PPO Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("PPO test run")
|
||||
@@ -99,7 +102,7 @@ def test_ppo_same_results():
|
||||
|
||||
training_session = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Train agent
|
||||
@@ -123,12 +126,14 @@ def test_ppo_same_results():
|
||||
timestamp_str=training_session.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=training_session.session_dir, env=training_session.env
|
||||
)
|
||||
|
||||
# Evaluate Agent again
|
||||
eval_session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -137,7 +142,10 @@ def test_ppo_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session1.env.training_config.load_agent = True
|
||||
eval_session1.env.training_config.agent_load_file = _get_session_path(training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session1.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session1.env.training_config
|
||||
|
||||
@@ -154,11 +162,13 @@ def test_ppo_same_results():
|
||||
timestamp_str=eval_session1.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session1.session_dir, env=eval_session1.env
|
||||
)
|
||||
|
||||
eval_session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -167,8 +177,10 @@ def test_ppo_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session2.env.training_config.load_agent = True
|
||||
eval_session2.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session2.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session2.env.training_config
|
||||
|
||||
@@ -185,18 +197,25 @@ def test_ppo_same_results():
|
||||
timestamp_str=eval_session2.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session2.session_dir, env=eval_session2.env
|
||||
)
|
||||
|
||||
# check if both eval outputs are the same
|
||||
assert compare_transaction_file(
|
||||
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
eval_session1.session_dir
|
||||
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir
|
||||
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/ppo_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
@@ -214,18 +233,23 @@ def test_ppo_same_results():
|
||||
timestamp_str=deterministic.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_transaction_file(
|
||||
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
deterministic.session_dir
|
||||
/ f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_ppo.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
def test_a2c_same_results():
|
||||
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same"""
|
||||
|
||||
"""Runs seeded and deterministic A2C Primaite sessions and checks that the results are the same."""
|
||||
print("")
|
||||
print("=======================")
|
||||
print("A2C test run")
|
||||
@@ -234,7 +258,7 @@ def test_a2c_same_results():
|
||||
|
||||
training_session = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Train agent
|
||||
@@ -258,12 +282,14 @@ def test_a2c_same_results():
|
||||
timestamp_str=training_session.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=training_session.session_dir, env=training_session.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=training_session.session_dir, env=training_session.env
|
||||
)
|
||||
|
||||
# Evaluate Agent again
|
||||
eval_session1 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -272,8 +298,10 @@ def test_a2c_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session1.env.training_config.load_agent = True
|
||||
eval_session1.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session1.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session1.env.training_config
|
||||
|
||||
@@ -290,11 +318,13 @@ def test_a2c_same_results():
|
||||
timestamp_str=eval_session1.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session1.session_dir, env=eval_session1.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session1.session_dir, env=eval_session1.env
|
||||
)
|
||||
|
||||
eval_session2 = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
# Get the number of steps (which is stored in the child config file)
|
||||
@@ -303,8 +333,10 @@ def test_a2c_same_results():
|
||||
|
||||
# load the agent that was trained previously
|
||||
eval_session2.env.training_config.load_agent = True
|
||||
eval_session2.env.training_config.agent_load_file = _get_session_path(
|
||||
training_session.session_timestamp) / f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
eval_session2.env.training_config.agent_load_file = (
|
||||
_get_session_path(training_session.session_timestamp)
|
||||
/ f"agent_saved_{training_session.timestamp_str}.zip"
|
||||
)
|
||||
|
||||
config_values = eval_session2.env.training_config
|
||||
|
||||
@@ -321,18 +353,25 @@ def test_a2c_same_results():
|
||||
timestamp_str=eval_session2.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=eval_session2.session_dir, env=eval_session2.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=eval_session2.session_dir, env=eval_session2.env
|
||||
)
|
||||
|
||||
# check if both eval outputs are the same
|
||||
assert compare_transaction_file(
|
||||
eval_session1.session_dir / f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir / f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
eval_session1.session_dir
|
||||
/ f"all_transactions_{eval_session1.timestamp_str}.csv",
|
||||
eval_session2.session_dir
|
||||
/ f"all_transactions_{eval_session2.timestamp_str}.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# deterministic run
|
||||
deterministic = TestSession(
|
||||
TEST_CONFIG_ROOT / "e2e/a2c_deterministic_seeded_training_config.yaml",
|
||||
data_manipulation_config_path()
|
||||
data_manipulation_config_path(),
|
||||
)
|
||||
|
||||
deterministic.env.training_config.deterministic = True
|
||||
@@ -350,10 +389,16 @@ def test_a2c_same_results():
|
||||
timestamp_str=deterministic.timestamp_str,
|
||||
)
|
||||
|
||||
_update_session_metadata_file(session_dir=deterministic.session_dir, env=deterministic.env)
|
||||
_update_session_metadata_file(
|
||||
session_dir=deterministic.session_dir, env=deterministic.env
|
||||
)
|
||||
|
||||
# check if both outputs are the same
|
||||
assert compare_transaction_file(
|
||||
deterministic.session_dir / f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
|
||||
) is True
|
||||
assert (
|
||||
compare_transaction_file(
|
||||
deterministic.session_dir
|
||||
/ f"all_transactions_{deterministic.timestamp_str}.csv",
|
||||
TEST_CONFIG_ROOT / "e2e/deterministic_test_outputs/deterministic_a2c.csv",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@@ -27,7 +27,8 @@ def env(request):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_without_obs.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.RESETTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node resets correctly.
|
||||
"""
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = "0",
|
||||
name = "node",
|
||||
node_type = NodeType.COMPUTER,
|
||||
priority = Priority.P1,
|
||||
hardware_state = starting_operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.COMPROMISED,
|
||||
file_system_state = FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig()
|
||||
node_id="0",
|
||||
name="node",
|
||||
node_type=NodeType.COMPUTER,
|
||||
priority=Priority.P1,
|
||||
hardware_state=starting_operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.COMPROMISED,
|
||||
file_system_state=FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig(),
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
|
||||
assert active_node.file_system_state_actual == FileSystemState.GOOD
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.BOOTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.BOOTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node boots correctly.
|
||||
"""
|
||||
"""Tests that a node boots correctly."""
|
||||
service_node = ServiceNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(
|
||||
name = "node",
|
||||
port = "80",
|
||||
software_state = SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(
|
||||
service_attributes
|
||||
name="node", port="80", software_state=SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
service_node.update_booting_status()
|
||||
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
assert service_attributes.software_state == SoftwareState.GOOD
|
||||
assert service_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
|
||||
],
|
||||
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
|
||||
)
|
||||
def test_node_shutdown_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node shutdown correctly.
|
||||
"""
|
||||
"""Tests that a node shutdown correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
active_node.update_shutdown_status()
|
||||
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
|
||||
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
|
||||
run_generic_set_actions(env)
|
||||
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
|
||||
def test_agent_is_executing_actions_from_both_spaces():
|
||||
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
|
||||
env = _get_primaite_env_from_config(
|
||||
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT
|
||||
/ "single_action_space_lay_down_config.yaml",
|
||||
)
|
||||
# Run environment with specified fixed blue agent actions only
|
||||
run_generic_set_actions(env)
|
||||
|
||||
Reference in New Issue
Block a user