#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
del self._agent_session._env.episode_av_reward_writer
|
||||
del self._agent_session._env.transaction_writer
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
@@ -112,9 +110,7 @@ def temp_primaite_session(request):
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch(
|
||||
"primaite.agents.agent.get_session_path", get_temp_session_path
|
||||
) as mck:
|
||||
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
@@ -130,9 +126,7 @@ def temp_session_path() -> Path:
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
@@ -16,9 +16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created temp session directory: {session_path}")
|
||||
return session_path
|
||||
|
||||
@@ -95,8 +95,6 @@ def test_rule_hash():
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash(
|
||||
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
|
||||
)
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
@@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_software_state_if_not_compromised(
|
||||
SoftwareState.OVERWHELMED
|
||||
)
|
||||
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
|
||||
|
||||
assert active_node.software_state == expected_state
|
||||
|
||||
@@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state):
|
||||
(HardwareState.ON, FileSystemState.CORRUPT),
|
||||
],
|
||||
)
|
||||
def test_file_system_change_if_not_compromised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
def test_file_system_change_if_not_compromised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change its file system state.
|
||||
|
||||
@@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised(
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_file_system_state_if_not_compromised(
|
||||
FileSystemState.CORRUPT
|
||||
)
|
||||
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
|
||||
|
||||
assert active_node.file_system_state_actual == expected_state
|
||||
|
||||
@@ -2,11 +2,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from primaite.environment.observations import (
|
||||
NodeLinkTable,
|
||||
NodeStatuses,
|
||||
ObservationsHandler,
|
||||
)
|
||||
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@@ -127,9 +123,7 @@ class TestNodeLinkTable:
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(
|
||||
0
|
||||
) # apply the 'do nothing' action
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
@@ -192,17 +186,15 @@ class TestNodeStatuses:
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
assert np.array_equal(
|
||||
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
|
||||
)
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
|
||||
@@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session):
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.evaluation_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(
|
||||
starting_operating_state, expected_operating_state
|
||||
):
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id="0",
|
||||
@@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(
|
||||
name="node", port="80", software_state=SoftwareState.COMPROMISED
|
||||
)
|
||||
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
|
||||
@@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state):
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_service_state_change_if_not_comprised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
def test_service_state_change_if_not_comprised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change the state of a running service.
|
||||
|
||||
@@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised(
|
||||
service = Service("TCP", 80, SoftwareState.GOOD)
|
||||
service_node.add_service(service)
|
||||
|
||||
service_node.set_service_state_if_not_compromised(
|
||||
"TCP", SoftwareState.OVERWHELMED
|
||||
)
|
||||
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
|
||||
|
||||
assert service_node.get_service_state("TCP") == expected_state
|
||||
|
||||
@@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# Creates an ACL rule
|
||||
@@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
|
||||
@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT
|
||||
|
||||
def test_legacy_lay_down_config_yaml_conversion():
|
||||
"""Tests the conversion of legacy lay down config files."""
|
||||
legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
|
||||
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
with open(legacy_path, "r") as file:
|
||||
legacy_dict = yaml.safe_load(file)
|
||||
@@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion():
|
||||
with open(new_path, "r") as file:
|
||||
new_dict = yaml.safe_load(file)
|
||||
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(
|
||||
legacy_dict
|
||||
)
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
|
||||
|
||||
for key, value in new_dict.items():
|
||||
assert converted_dict[key] == value
|
||||
@@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():
|
||||
|
||||
def test_create_config_values_main_from_file():
|
||||
"""Tests creating an instance of TrainingConfig from file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
training_config.load(new_path)
|
||||
|
||||
|
||||
def test_create_config_values_main_from_legacy_file():
|
||||
"""Tests creating an instance of TrainingConfig from legacy file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
|
||||
training_config.load(new_path, legacy_file=True)
|
||||
|
||||
Reference in New Issue
Block a user