#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession):
return self
def __exit__(self, type, value, tb):
del self._agent_session._env.episode_av_reward_writer
del self._agent_session._env.transaction_writer
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@@ -112,9 +110,7 @@ def temp_primaite_session(request):
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch(
"primaite.agents.agent.get_session_path", get_temp_session_path
) as mck:
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
@@ -130,9 +126,7 @@ def temp_session_path() -> Path:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = (
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
)
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path

View File

@@ -16,9 +16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = (
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
)
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created temp session directory: {session_path}")
return session_path

View File

@@ -95,8 +95,6 @@ def test_rule_hash():
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash(
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote

View File

@@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
1,
)
active_node.set_software_state_if_not_compromised(
SoftwareState.OVERWHELMED
)
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
assert active_node.software_state == expected_state
@@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state):
(HardwareState.ON, FileSystemState.CORRUPT),
],
)
def test_file_system_change_if_not_compromised(
operating_state, expected_state
):
def test_file_system_change_if_not_compromised(operating_state, expected_state):
"""
Test that a node cannot change its file system state.
@@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised(
1,
)
active_node.set_file_system_state_if_not_compromised(
FileSystemState.CORRUPT
)
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
assert active_node.file_system_state_actual == expected_state

View File

@@ -2,11 +2,7 @@
import numpy as np
import pytest
from primaite.environment.observations import (
NodeLinkTable,
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
from tests import TEST_CONFIG_ROOT
@@ -127,9 +123,7 @@ class TestNodeLinkTable:
with temp_primaite_session as session:
env = session.env
# act = np.asarray([0,])
obs, reward, done, info = env.step(
0
) # apply the 'do nothing' action
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs,
@@ -192,17 +186,15 @@ class TestNodeStatuses:
with temp_primaite_session as session:
env = session.env
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
)
print(obs)
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],

View File

@@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session):
# Check that both the transactions and av reward csv files exist
for file in session.learning_path.iterdir():
if file.suffix == ".csv":
assert (
"all_transactions" in file.name
or "average_reward_per_episode" in file.name
)
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
# Check that both the transactions and av reward csv files exist
for file in session.evaluation_path.iterdir():
if file.suffix == ".csv":
assert (
"all_transactions" in file.name
or "average_reward_per_episode" in file.name
)
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
_LOGGER.debug("Inspecting files in temp session path...")
for dir_path, dir_names, file_names in os.walk(session_path):

View File

@@ -1,13 +1,7 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode
"starting_operating_state, expected_operating_state",
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(
starting_operating_state, expected_operating_state
):
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id="0",
@@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
service_node.add_service(service_attributes)
for x in range(5):

View File

@@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state):
(HardwareState.ON, SoftwareState.OVERWHELMED),
],
)
def test_service_state_change_if_not_comprised(
operating_state, expected_state
):
def test_service_state_change_if_not_comprised(operating_state, expected_state):
"""
Test that a node cannot change the state of a running service.
@@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised(
service = Service("TCP", 80, SoftwareState.GOOD)
service_node.add_service(service)
service_node.set_service_state_if_not_compromised(
"TCP", SoftwareState.OVERWHELMED
)
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
assert service_node.get_service_state("TCP") == expected_state

View File

@@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# Creates an ACL rule
@@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],

View File

@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT
def test_legacy_lay_down_config_yaml_conversion():
"""Tests the conversion of legacy lay down config files."""
legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
with open(legacy_path, "r") as file:
legacy_dict = yaml.safe_load(file)
@@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion():
with open(new_path, "r") as file:
new_dict = yaml.safe_load(file)
converted_dict = training_config.convert_legacy_training_config_dict(
legacy_dict
)
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
for key, value in new_dict.items():
assert converted_dict[key] == value
@@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():
def test_create_config_values_main_from_file():
"""Tests creating an instance of TrainingConfig from file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
training_config.load(new_path)
def test_create_config_values_main_from_legacy_file():
"""Tests creating an instance of TrainingConfig from legacy file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
training_config.load(new_path, legacy_file=True)