From 34b294f89a5224207d3154d79232e525a49731e0 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 3 Jul 2023 20:40:38 +0100 Subject: [PATCH] #917 - Reinstalled the pre-commit hook --- src/primaite/environment/observations.py | 4 +- src/primaite/environment/primaite_env.py | 71 +++++++----------------- src/primaite/transactions/transaction.py | 15 +---- tests/test_red_random_agent_behaviour.py | 1 + 4 files changed, 24 insertions(+), 67 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 511fb008..23bc4a39 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -266,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent): for service in services: structure.append(f"node_{node_id}_service_{service}_state_NONE") for state in SoftwareState: - structure.append( - f"node_{node_id}_service_{service}_state_{state.name}" - ) + structure.append(f"node_{node_id}_service_{service}_state_{state.name}") return structure diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index d7b68045..03c23f93 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,14 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy -import csv import logging import uuid as uuid -from datetime import datetime from pathlib import Path -from typing import Dict, Final, Tuple, Union from random import choice, randint, sample, uniform -from typing import Dict, Tuple, Union +from typing import Dict, Final, Tuple, Union import networkx as nx import numpy as np @@ -321,11 +318,7 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction( - self.agent_identifier, - self.actual_episode_count, - self.step_count - ) + transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction transaction.obs_space = self.obs_handler._flat_observation @@ -436,12 +429,7 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - print( - " Protocol: " - + protocol.get_name().name - + ", Load: " - + str(protocol.get_load()) - ) + print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): """ @@ -456,13 +444,9 @@ class Primaite(Env): self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 6 - ): # ACL actions in multidiscrete form have len 6 + elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 4 - ): # Node actions in multdiscrete (array) from have len 4 + elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: logging.error("Invalid action type found") @@ -528,9 +512,7 @@ class Primaite(Env): return elif property_action == 1: # Patch (valid action if it's good or compromised) - node.set_service_state( - self.services_list[service_index], SoftwareState.PATCHING - ) + node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING) else: # Node is not of Service Type return @@ -1238,11 +1220,7 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = { - k + len(acl_action_dict) - 1: v - for k, v in node_action_dict.items() - if k != 0 - } + new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} @@ -1256,11 +1234,8 @@ class Primaite(Env): # Decide how many nodes become compromised node_list = list(self.nodes.values()) - computers = [node for node in node_list if - node.node_type == NodeType.COMPUTER] - max_num_nodes_compromised = len( - computers - ) # only computers can become compromised + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len(computers) # only computers can become compromised # random select between 1 and max_num_nodes_compromised num_nodes_to_compromise = randint(1, max_num_nodes_compromised) @@ -1271,9 +1246,7 @@ class Primaite(Env): source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised - max_step_compromised = ( - self.episode_steps // 2 - ) # always compromise in first half of episode + max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] @@ -1283,15 +1256,13 @@ class Primaite(Env): _LOGGER.error(msg) raise Exception(msg) - servers = [node for node in node_list if - node.node_type == NodeType.SERVER] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint(2, - max_step_compromised + 1) # step compromised + _start_step = randint(2, max_step_compromised + 1) # step compromised pol_service_name = choice(list(node.services.keys())) source_node_service = choice(list(source_node.services.values())) @@ -1316,8 +1287,7 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint(_start_step + 2, - int(self.episode_steps * 0.8)) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith @@ -1325,9 +1295,7 @@ class Primaite(Env): ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port - ier_mission_criticality = ( - 0 # Red IER will never be important to green agent success - ) + ier_mission_criticality = 0 # Red IER will never be important to green agent success # We choose a node to attack based on the first that applies: # a. Green IERs, select dest node of the red ier based on dest node of green IER # b. Attack a random server that doesn't have a DENY acl rule in default config @@ -1340,16 +1308,15 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: # If still none found choose from all servers - possible_ier_destinations = [server.node_id for server in - servers] + possible_ier_destinations = [server.node_id for server in servers] ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 763dc458..7db2444a 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -9,12 +9,7 @@ from primaite.common.enums import AgentIdentifier class Transaction(object): """Transaction class.""" - def __init__( - self, - agent_identifier: AgentIdentifier, - episode_number: int, - step_number: int - ): + def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int): """ Transaction constructor. @@ -62,18 +57,14 @@ class Transaction(object): # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] header = header + action_header + self.obs_space_description - + row = [ str(self.timestamp), str(self.episode_number), str(self.step_number), str(self.reward), ] - row = ( - row - + _turn_action_space_to_array(self.action_space) - + self.obs_space.tolist() - ) + row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist() return header, row diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 8cf60236..f8885f3e 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -4,6 +4,7 @@ from primaite.config.lay_down_config import data_manipulation_config_path from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT + @pytest.mark.parametrize( "temp_primaite_session", [