diff --git a/.gitignore b/.gitignore index 5adbdc57..b65d1fd8 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,4 @@ dmypy.json # Cython debug symbols cython_debug/ -.idea/ \ No newline at end of file +.idea/ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 58932c4c..eb0bc5de 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,14 +3,14 @@ import copy import csv import logging +import uuid as uuid from datetime import datetime from pathlib import Path -from random import randint, choice, uniform, sample +from random import choice, randint, sample, uniform from typing import Dict, Tuple, Union import networkx as nx import numpy as np -import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -60,12 +60,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -448,11 +448,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1238,7 +1238,6 @@ class Primaite(Env): def create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" - # Reset the current red iers and red node pol self.red_iers = {} self.red_node_pol = {} @@ -1260,7 +1259,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1277,16 +1276,10 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint( - 2, max_step_compromised + 1 - ) # step compromised - pol_service_name = choice( - list(node.services.keys()) - ) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) - source_node_service = choice( - list(source_node.services.values()) - ) + source_node_service = choice(list(source_node.services.values())) red_pol = NodeStateInstructionRed( _id=_id, @@ -1299,7 +1292,7 @@ class Primaite(Env): _pol_state=SoftwareState.COMPROMISED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[_id] = red_pol @@ -1308,15 +1301,11 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint( - _start_step + 2, int(self.episode_steps * 0.8) - ) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = uniform(0.4, 0.8) * choice( - bandwidths - ) + ier_load = uniform(0.4, 0.8) * choice(bandwidths) ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port @@ -1335,10 +1324,10 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1376,6 +1365,6 @@ class Primaite(Env): _pol_state=SoftwareState.OVERWHELMED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 9ae917e9..2f7d0622 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -153,4 +153,4 @@ class NodeStateInstructionRed(object): f"source_node_service={self.source_node_service}, " f"source_node_service_state={self.source_node_service_state}" f")" - ) \ No newline at end of file + ) diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index c9189c26..476a08f1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -2,6 +2,7 @@ from datetime import datetime from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -41,14 +42,14 @@ def test_random_red_agent_behaviour(): # RUN TWICE so we can make sure that red agent is randomised for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() session_path = _get_temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( - training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, @@ -64,7 +65,7 @@ def test_random_red_agent_behaviour(): # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): for key in list_of_node_instructions[index].keys(): - instruction: NodeInstructionRed = list_of_node_instructions[index][key] + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] print(f"run {index}") print(f"{key} start step: {instruction.get_start_step()}") print(f"{key} end step: {instruction.get_end_step()}")