From 4299170ce42e68e392f5be2e1ef646c62546c971 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 09:46:52 +0100 Subject: [PATCH] #1522: added a check for existing links in laydown + test that checks if red agent instructions are random --- src/primaite/environment/primaite_env.py | 6 ++++ tests/test_red_random_agent_behaviour.py | 36 ++++++++++++------------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e592e21f..58932c4c 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1265,6 +1265,12 @@ class Primaite(Env): # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index a86e32c1..c9189c26 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,5 +1,6 @@ -from datetime import time, datetime +from datetime import datetime +from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -24,9 +25,6 @@ def run_generic(env, config_values): if done: break - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - # Reset the environment at the end of the episode env.reset() @@ -40,6 +38,8 @@ def test_random_red_agent_behaviour(): When the initial state is OFF compared to reference state which is ON. """ list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised for i in range(2): """Takes a config path and returns the created instance of Primaite.""" @@ -49,7 +49,7 @@ def test_random_red_agent_behaviour(): timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, @@ -57,18 +57,18 @@ def test_random_red_agent_behaviour(): training_config = env.training_config training_config.num_steps = env.episode_steps - # TOOD: This needs t be refactored to happen outside. Should be part of - # a main Session class. - if training_config.agent_identifier == "GENERIC": - run_generic(env, training_config) - all_red_actions = env.red_node_pol - list_of_node_instructions.append(all_red_actions) + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") - # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) - print(list_of_node_instructions[0]["1"].get_start_step()) - print(list_of_node_instructions[0]["1"].get_end_step()) - print(list_of_node_instructions[0]["1"].get_target_node_id()) - print(list_of_node_instructions[1]["1"].get_start_step()) - print(list_of_node_instructions[1]["1"].get_end_step()) - print(list_of_node_instructions[1]["1"].get_target_node_id()) assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])