#1522: run pre-commit

This commit is contained in:
Czar Echavez
2023-07-03 10:08:25 +01:00
parent ae56827bae
commit 6c4a538b41
4 changed files with 28 additions and 38 deletions

2
.gitignore vendored
View File

@@ -138,4 +138,4 @@ dmypy.json
# Cython debug symbols # Cython debug symbols
cython_debug/ cython_debug/
.idea/ .idea/

View File

@@ -3,14 +3,14 @@
import copy import copy
import csv import csv
import logging import logging
import uuid as uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from random import randint, choice, uniform, sample from random import choice, randint, sample, uniform
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import uuid as uuid
import yaml import yaml
from gym import Env, spaces from gym import Env, spaces
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@@ -60,12 +60,12 @@ class Primaite(Env):
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__( def __init__(
self, self,
training_config_path: Union[str, Path], training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path], lay_down_config_path: Union[str, Path],
transaction_list, transaction_list,
session_path: Path, session_path: Path,
timestamp_str: str, timestamp_str: str,
): ):
""" """
The Primaite constructor. The Primaite constructor.
@@ -448,11 +448,11 @@ class Primaite(Env):
elif self.training_config.action_type == ActionType.ACL: elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action) self.apply_actions_to_acl(_action)
elif ( elif (
len(self.action_dict[_action]) == 6 len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6 ): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action) self.apply_actions_to_acl(_action)
elif ( elif (
len(self.action_dict[_action]) == 4 len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4 ): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action) self.apply_actions_to_nodes(_action)
else: else:
@@ -1238,7 +1238,6 @@ class Primaite(Env):
def create_random_red_agent(self): def create_random_red_agent(self):
"""Decide on random red agent for the episode to be called in env.reset().""" """Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol # Reset the current red iers and red node pol
self.red_iers = {} self.red_iers = {}
self.red_node_pol = {} self.red_node_pol = {}
@@ -1260,7 +1259,7 @@ class Primaite(Env):
# For each of the nodes to be compromised decide which step they become compromised # For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = ( max_step_compromised = (
self.episode_steps // 2 self.episode_steps // 2
) # always compromise in first half of episode ) # always compromise in first half of episode
# Bandwidth for all links # Bandwidth for all links
@@ -1277,16 +1276,10 @@ class Primaite(Env):
# 1: Use Node PoL to set node to compromised # 1: Use Node PoL to set node to compromised
_id = str(uuid.uuid4()) _id = str(uuid.uuid4())
_start_step = randint( _start_step = randint(2, max_step_compromised + 1) # step compromised
2, max_step_compromised + 1 pol_service_name = choice(list(node.services.keys()))
) # step compromised
pol_service_name = choice(
list(node.services.keys())
)
source_node_service = choice( source_node_service = choice(list(source_node.services.values()))
list(source_node.services.values())
)
red_pol = NodeStateInstructionRed( red_pol = NodeStateInstructionRed(
_id=_id, _id=_id,
@@ -1299,7 +1292,7 @@ class Primaite(Env):
_pol_state=SoftwareState.COMPROMISED, _pol_state=SoftwareState.COMPROMISED,
_pol_source_node_id=source_node.node_id, _pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name, _pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state _pol_source_node_service_state=source_node_service.software_state,
) )
self.red_node_pol[_id] = red_pol self.red_node_pol[_id] = red_pol
@@ -1308,15 +1301,11 @@ class Primaite(Env):
ier_id = str(uuid.uuid4()) ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode # Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = randint( ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
_start_step + 2, int(self.episode_steps * 0.8)
)
ier_end_step = self.episode_steps ier_end_step = self.episode_steps
# Randomise the load, as a percentage of a random link bandwith # Randomise the load, as a percentage of a random link bandwith
ier_load = uniform(0.4, 0.8) * choice( ier_load = uniform(0.4, 0.8) * choice(bandwidths)
bandwidths
)
ier_protocol = pol_service_name # Same protocol as compromised node ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.services[pol_service_name] ier_service = node.services[pol_service_name]
ier_port = ier_service.port ier_port = ier_service.port
@@ -1335,10 +1324,10 @@ class Primaite(Env):
if len(possible_ier_destinations) < 1: if len(possible_ier_destinations) < 1:
for server in servers: for server in servers:
if not self.acl.is_blocked( if not self.acl.is_blocked(
node.ip_address, node.ip_address,
server.ip_address, server.ip_address,
ier_service, ier_service,
ier_port, ier_port,
): ):
possible_ier_destinations.append(server.node_id) possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1: if len(possible_ier_destinations) < 1:
@@ -1376,6 +1365,6 @@ class Primaite(Env):
_pol_state=SoftwareState.OVERWHELMED, _pol_state=SoftwareState.OVERWHELMED,
_pol_source_node_id=source_node.node_id, _pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name, _pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state _pol_source_node_service_state=source_node_service.software_state,
) )
self.red_node_pol[o_pol_id] = o_red_pol self.red_node_pol[o_pol_id] = o_red_pol

View File

@@ -153,4 +153,4 @@ class NodeStateInstructionRed(object):
f"source_node_service={self.source_node_service}, " f"source_node_service={self.source_node_service}, "
f"source_node_service_state={self.source_node_service_state}" f"source_node_service_state={self.source_node_service_state}"
f")" f")"
) )

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from primaite.config.lay_down_config import data_manipulation_config_path from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.environment.primaite_env import Primaite from primaite.environment.primaite_env import Primaite
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path from tests.conftest import _get_temp_session_path
@@ -41,14 +42,14 @@ def test_random_red_agent_behaviour():
# RUN TWICE so we can make sure that red agent is randomised # RUN TWICE so we can make sure that red agent is randomised
for i in range(2): for i in range(2):
"""Takes a config path and returns the created instance of Primaite.""" """Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now() session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp) session_path = _get_temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite( env = Primaite(
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=data_manipulation_config_path(), lay_down_config_path=data_manipulation_config_path(),
transaction_list=[], transaction_list=[],
session_path=session_path, session_path=session_path,
@@ -64,7 +65,7 @@ def test_random_red_agent_behaviour():
# compare instructions to make sure that red instructions are truly random # compare instructions to make sure that red instructions are truly random
for index, instruction in enumerate(list_of_node_instructions): for index, instruction in enumerate(list_of_node_instructions):
for key in list_of_node_instructions[index].keys(): for key in list_of_node_instructions[index].keys():
instruction: NodeInstructionRed = list_of_node_instructions[index][key] instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
print(f"run {index}") print(f"run {index}")
print(f"{key} start step: {instruction.get_start_step()}") print(f"{key} start step: {instruction.get_start_step()}")
print(f"{key} end step: {instruction.get_end_step()}") print(f"{key} end step: {instruction.get_end_step()}")