#917 - Reinstalled the pre-commit hook

This commit is contained in:
Chris McCarthy
2023-07-03 20:40:38 +01:00
parent 410d5abe12
commit 34b294f89a
4 changed files with 24 additions and 67 deletions

View File

@@ -266,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent):
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(
f"node_{node_id}_service_{service}_state_{state.name}"
)
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
return structure

View File

@@ -1,14 +1,11 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
import copy
import csv
import logging
import uuid as uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Tuple, Union
from random import choice, randint, sample, uniform
from typing import Dict, Tuple, Union
from typing import Dict, Final, Tuple, Union
import networkx as nx
import numpy as np
@@ -321,11 +318,7 @@ class Primaite(Env):
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
self.agent_identifier,
self.actual_episode_count,
self.step_count
)
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
# Load the initial observation space into the transaction
transaction.obs_space = self.obs_handler._flat_observation
@@ -436,12 +429,7 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
print(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
"""
@@ -456,13 +444,9 @@ class Primaite(Env):
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
@@ -528,9 +512,7 @@ class Primaite(Env):
return
elif property_action == 1:
# Patch (valid action if it's good or compromised)
node.set_service_state(
self.services_list[service_index], SoftwareState.PATCHING
)
node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING)
else:
# Node is not of Service Type
return
@@ -1238,11 +1220,7 @@ class Primaite(Env):
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
@@ -1256,11 +1234,8 @@ class Primaite(Env):
# Decide how many nodes become compromised
node_list = list(self.nodes.values())
computers = [node for node in node_list if
node.node_type == NodeType.COMPUTER]
max_num_nodes_compromised = len(
computers
) # only computers can become compromised
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
max_num_nodes_compromised = len(computers) # only computers can become compromised
# random select between 1 and max_num_nodes_compromised
num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
@@ -1271,9 +1246,7 @@ class Primaite(Env):
source_node = choice(nodes_to_be_compromised)
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
) # always compromise in first half of episode
max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode
# Bandwidth for all links
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
@@ -1283,15 +1256,13 @@ class Primaite(Env):
_LOGGER.error(msg)
raise Exception(msg)
servers = [node for node in node_list if
node.node_type == NodeType.SERVER]
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
for n, node in enumerate(nodes_to_be_compromised):
# 1: Use Node PoL to set node to compromised
_id = str(uuid.uuid4())
_start_step = randint(2,
max_step_compromised + 1) # step compromised
_start_step = randint(2, max_step_compromised + 1) # step compromised
pol_service_name = choice(list(node.services.keys()))
source_node_service = choice(list(source_node.services.values()))
@@ -1316,8 +1287,7 @@ class Primaite(Env):
ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = randint(_start_step + 2,
int(self.episode_steps * 0.8))
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
ier_end_step = self.episode_steps
# Randomise the load, as a percentage of a random link bandwith
@@ -1325,9 +1295,7 @@ class Primaite(Env):
ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.services[pol_service_name]
ier_port = ier_service.port
ier_mission_criticality = (
0 # Red IER will never be important to green agent success
)
ier_mission_criticality = 0 # Red IER will never be important to green agent success
# We choose a node to attack based on the first that applies:
# a. Green IERs, select dest node of the red ier based on dest node of green IER
# b. Attack a random server that doesn't have a DENY acl rule in default config
@@ -1340,16 +1308,15 @@ class Primaite(Env):
if len(possible_ier_destinations) < 1:
for server in servers:
if not self.acl.is_blocked(
node.ip_address,
server.ip_address,
ier_service,
ier_port,
node.ip_address,
server.ip_address,
ier_service,
ier_port,
):
possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1:
# If still none found choose from all servers
possible_ier_destinations = [server.node_id for server in
servers]
possible_ier_destinations = [server.node_id for server in servers]
ier_dest = choice(possible_ier_destinations)
self.red_iers[ier_id] = IER(
ier_id,

View File

@@ -9,12 +9,7 @@ from primaite.common.enums import AgentIdentifier
class Transaction(object):
"""Transaction class."""
def __init__(
self,
agent_identifier: AgentIdentifier,
episode_number: int,
step_number: int
):
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int):
"""
Transaction constructor.
@@ -62,18 +57,14 @@ class Transaction(object):
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + self.obs_space_description
row = [
str(self.timestamp),
str(self.episode_number),
str(self.step_number),
str(self.reward),
]
row = (
row
+ _turn_action_space_to_array(self.action_space)
+ self.obs_space.tolist()
)
row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
return header, row

View File

@@ -4,6 +4,7 @@ from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT
@pytest.mark.parametrize(
"temp_primaite_session",
[