#1522: run pre-commit

This commit is contained in:
Czar Echavez
2023-07-03 10:08:25 +01:00
parent ae56827bae
commit 6c4a538b41
4 changed files with 28 additions and 38 deletions

2
.gitignore vendored
View File

@@ -138,4 +138,4 @@ dmypy.json
# Cython debug symbols
cython_debug/
.idea/
.idea/

View File

@@ -3,14 +3,14 @@
import copy
import csv
import logging
import uuid as uuid
from datetime import datetime
from pathlib import Path
from random import randint, choice, uniform, sample
from random import choice, randint, sample, uniform
from typing import Dict, Tuple, Union
import networkx as nx
import numpy as np
import uuid as uuid
import yaml
from gym import Env, spaces
from matplotlib import pyplot as plt
@@ -60,12 +60,12 @@ class Primaite(Env):
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
"""
The Primaite constructor.
@@ -448,11 +448,11 @@ class Primaite(Env):
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
@@ -1238,7 +1238,6 @@ class Primaite(Env):
def create_random_red_agent(self):
"""Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol
self.red_iers = {}
self.red_node_pol = {}
@@ -1260,7 +1259,7 @@ class Primaite(Env):
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
self.episode_steps // 2
) # always compromise in first half of episode
# Bandwidth for all links
@@ -1277,16 +1276,10 @@ class Primaite(Env):
# 1: Use Node PoL to set node to compromised
_id = str(uuid.uuid4())
_start_step = randint(
2, max_step_compromised + 1
) # step compromised
pol_service_name = choice(
list(node.services.keys())
)
_start_step = randint(2, max_step_compromised + 1) # step compromised
pol_service_name = choice(list(node.services.keys()))
source_node_service = choice(
list(source_node.services.values())
)
source_node_service = choice(list(source_node.services.values()))
red_pol = NodeStateInstructionRed(
_id=_id,
@@ -1299,7 +1292,7 @@ class Primaite(Env):
_pol_state=SoftwareState.COMPROMISED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state
_pol_source_node_service_state=source_node_service.software_state,
)
self.red_node_pol[_id] = red_pol
@@ -1308,15 +1301,11 @@ class Primaite(Env):
ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = randint(
_start_step + 2, int(self.episode_steps * 0.8)
)
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
ier_end_step = self.episode_steps
# Randomise the load, as a percentage of a random link bandwith
ier_load = uniform(0.4, 0.8) * choice(
bandwidths
)
ier_load = uniform(0.4, 0.8) * choice(bandwidths)
ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.services[pol_service_name]
ier_port = ier_service.port
@@ -1335,10 +1324,10 @@ class Primaite(Env):
if len(possible_ier_destinations) < 1:
for server in servers:
if not self.acl.is_blocked(
node.ip_address,
server.ip_address,
ier_service,
ier_port,
node.ip_address,
server.ip_address,
ier_service,
ier_port,
):
possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1:
@@ -1376,6 +1365,6 @@ class Primaite(Env):
_pol_state=SoftwareState.OVERWHELMED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state
_pol_source_node_service_state=source_node_service.software_state,
)
self.red_node_pol[o_pol_id] = o_red_pol

View File

@@ -153,4 +153,4 @@ class NodeStateInstructionRed(object):
f"source_node_service={self.source_node_service}, "
f"source_node_service_state={self.source_node_service_state}"
f")"
)
)

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.environment.primaite_env import Primaite
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path
@@ -41,14 +42,14 @@ def test_random_red_agent_behaviour():
# RUN TWICE so we can make sure that red agent is randomised
for i in range(2):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=data_manipulation_config_path(),
transaction_list=[],
session_path=session_path,
@@ -64,7 +65,7 @@ def test_random_red_agent_behaviour():
# compare instructions to make sure that red instructions are truly random
for index, instruction in enumerate(list_of_node_instructions):
for key in list_of_node_instructions[index].keys():
instruction: NodeInstructionRed = list_of_node_instructions[index][key]
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
print(f"run {index}")
print(f"{key} start step: {instruction.get_start_step()}")
print(f"{key} end step: {instruction.get_end_step()}")