#1522: run pre-commit
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -138,4 +138,4 @@ dmypy.json
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
.idea/
|
||||
.idea/
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
import copy
|
||||
import csv
|
||||
import logging
|
||||
import uuid as uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from random import randint, choice, uniform, sample
|
||||
from random import choice, randint, sample, uniform
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import uuid as uuid
|
||||
import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
@@ -60,12 +60,12 @@ class Primaite(Env):
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
The Primaite constructor.
|
||||
@@ -448,11 +448,11 @@ class Primaite(Env):
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
@@ -1238,7 +1238,6 @@ class Primaite(Env):
|
||||
|
||||
def create_random_red_agent(self):
|
||||
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||
|
||||
# Reset the current red iers and red node pol
|
||||
self.red_iers = {}
|
||||
self.red_node_pol = {}
|
||||
@@ -1260,7 +1259,7 @@ class Primaite(Env):
|
||||
|
||||
# For each of the nodes to be compromised decide which step they become compromised
|
||||
max_step_compromised = (
|
||||
self.episode_steps // 2
|
||||
self.episode_steps // 2
|
||||
) # always compromise in first half of episode
|
||||
|
||||
# Bandwidth for all links
|
||||
@@ -1277,16 +1276,10 @@ class Primaite(Env):
|
||||
# 1: Use Node PoL to set node to compromised
|
||||
|
||||
_id = str(uuid.uuid4())
|
||||
_start_step = randint(
|
||||
2, max_step_compromised + 1
|
||||
) # step compromised
|
||||
pol_service_name = choice(
|
||||
list(node.services.keys())
|
||||
)
|
||||
_start_step = randint(2, max_step_compromised + 1) # step compromised
|
||||
pol_service_name = choice(list(node.services.keys()))
|
||||
|
||||
source_node_service = choice(
|
||||
list(source_node.services.values())
|
||||
)
|
||||
source_node_service = choice(list(source_node.services.values()))
|
||||
|
||||
red_pol = NodeStateInstructionRed(
|
||||
_id=_id,
|
||||
@@ -1299,7 +1292,7 @@ class Primaite(Env):
|
||||
_pol_state=SoftwareState.COMPROMISED,
|
||||
_pol_source_node_id=source_node.node_id,
|
||||
_pol_source_node_service=source_node_service.name,
|
||||
_pol_source_node_service_state=source_node_service.software_state
|
||||
_pol_source_node_service_state=source_node_service.software_state,
|
||||
)
|
||||
|
||||
self.red_node_pol[_id] = red_pol
|
||||
@@ -1308,15 +1301,11 @@ class Primaite(Env):
|
||||
|
||||
ier_id = str(uuid.uuid4())
|
||||
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||
ier_start_step = randint(
|
||||
_start_step + 2, int(self.episode_steps * 0.8)
|
||||
)
|
||||
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
|
||||
ier_end_step = self.episode_steps
|
||||
|
||||
# Randomise the load, as a percentage of a random link bandwith
|
||||
ier_load = uniform(0.4, 0.8) * choice(
|
||||
bandwidths
|
||||
)
|
||||
ier_load = uniform(0.4, 0.8) * choice(bandwidths)
|
||||
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||
ier_service = node.services[pol_service_name]
|
||||
ier_port = ier_service.port
|
||||
@@ -1335,10 +1324,10 @@ class Primaite(Env):
|
||||
if len(possible_ier_destinations) < 1:
|
||||
for server in servers:
|
||||
if not self.acl.is_blocked(
|
||||
node.ip_address,
|
||||
server.ip_address,
|
||||
ier_service,
|
||||
ier_port,
|
||||
node.ip_address,
|
||||
server.ip_address,
|
||||
ier_service,
|
||||
ier_port,
|
||||
):
|
||||
possible_ier_destinations.append(server.node_id)
|
||||
if len(possible_ier_destinations) < 1:
|
||||
@@ -1376,6 +1365,6 @@ class Primaite(Env):
|
||||
_pol_state=SoftwareState.OVERWHELMED,
|
||||
_pol_source_node_id=source_node.node_id,
|
||||
_pol_source_node_service=source_node_service.name,
|
||||
_pol_source_node_service_state=source_node_service.software_state
|
||||
_pol_source_node_service_state=source_node_service.software_state,
|
||||
)
|
||||
self.red_node_pol[o_pol_id] = o_red_pol
|
||||
|
||||
@@ -153,4 +153,4 @@ class NodeStateInstructionRed(object):
|
||||
f"source_node_service={self.source_node_service}, "
|
||||
f"source_node_service_state={self.source_node_service_state}"
|
||||
f")"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
||||
|
||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_temp_session_path
|
||||
|
||||
@@ -41,14 +42,14 @@ def test_random_red_agent_behaviour():
|
||||
|
||||
# RUN TWICE so we can make sure that red agent is randomised
|
||||
for i in range(2):
|
||||
|
||||
"""Takes a config path and returns the created instance of Primaite."""
|
||||
session_timestamp: datetime = datetime.now()
|
||||
session_path = _get_temp_session_path(session_timestamp)
|
||||
|
||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
env = Primaite(
|
||||
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "one_node_states_on_off_main_config.yaml",
|
||||
lay_down_config_path=data_manipulation_config_path(),
|
||||
transaction_list=[],
|
||||
session_path=session_path,
|
||||
@@ -64,7 +65,7 @@ def test_random_red_agent_behaviour():
|
||||
# compare instructions to make sure that red instructions are truly random
|
||||
for index, instruction in enumerate(list_of_node_instructions):
|
||||
for key in list_of_node_instructions[index].keys():
|
||||
instruction: NodeInstructionRed = list_of_node_instructions[index][key]
|
||||
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
|
||||
print(f"run {index}")
|
||||
print(f"{key} start step: {instruction.get_start_step()}")
|
||||
print(f"{key} end step: {instruction.get_end_step()}")
|
||||
|
||||
Reference in New Issue
Block a user