#1522: run pre-commit
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -138,4 +138,4 @@ dmypy.json
|
|||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
.idea/
|
.idea/
|
||||||
|
|||||||
@@ -3,14 +3,14 @@
|
|||||||
import copy
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
import uuid as uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import randint, choice, uniform, sample
|
from random import choice, randint, sample, uniform
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import uuid as uuid
|
|
||||||
import yaml
|
import yaml
|
||||||
from gym import Env, spaces
|
from gym import Env, spaces
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
@@ -60,12 +60,12 @@ class Primaite(Env):
|
|||||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
training_config_path: Union[str, Path],
|
training_config_path: Union[str, Path],
|
||||||
lay_down_config_path: Union[str, Path],
|
lay_down_config_path: Union[str, Path],
|
||||||
transaction_list,
|
transaction_list,
|
||||||
session_path: Path,
|
session_path: Path,
|
||||||
timestamp_str: str,
|
timestamp_str: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
The Primaite constructor.
|
The Primaite constructor.
|
||||||
@@ -448,11 +448,11 @@ class Primaite(Env):
|
|||||||
elif self.training_config.action_type == ActionType.ACL:
|
elif self.training_config.action_type == ActionType.ACL:
|
||||||
self.apply_actions_to_acl(_action)
|
self.apply_actions_to_acl(_action)
|
||||||
elif (
|
elif (
|
||||||
len(self.action_dict[_action]) == 6
|
len(self.action_dict[_action]) == 6
|
||||||
): # ACL actions in multidiscrete form have len 6
|
): # ACL actions in multidiscrete form have len 6
|
||||||
self.apply_actions_to_acl(_action)
|
self.apply_actions_to_acl(_action)
|
||||||
elif (
|
elif (
|
||||||
len(self.action_dict[_action]) == 4
|
len(self.action_dict[_action]) == 4
|
||||||
): # Node actions in multdiscrete (array) from have len 4
|
): # Node actions in multdiscrete (array) from have len 4
|
||||||
self.apply_actions_to_nodes(_action)
|
self.apply_actions_to_nodes(_action)
|
||||||
else:
|
else:
|
||||||
@@ -1238,7 +1238,6 @@ class Primaite(Env):
|
|||||||
|
|
||||||
def create_random_red_agent(self):
|
def create_random_red_agent(self):
|
||||||
"""Decide on random red agent for the episode to be called in env.reset()."""
|
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||||
|
|
||||||
# Reset the current red iers and red node pol
|
# Reset the current red iers and red node pol
|
||||||
self.red_iers = {}
|
self.red_iers = {}
|
||||||
self.red_node_pol = {}
|
self.red_node_pol = {}
|
||||||
@@ -1260,7 +1259,7 @@ class Primaite(Env):
|
|||||||
|
|
||||||
# For each of the nodes to be compromised decide which step they become compromised
|
# For each of the nodes to be compromised decide which step they become compromised
|
||||||
max_step_compromised = (
|
max_step_compromised = (
|
||||||
self.episode_steps // 2
|
self.episode_steps // 2
|
||||||
) # always compromise in first half of episode
|
) # always compromise in first half of episode
|
||||||
|
|
||||||
# Bandwidth for all links
|
# Bandwidth for all links
|
||||||
@@ -1277,16 +1276,10 @@ class Primaite(Env):
|
|||||||
# 1: Use Node PoL to set node to compromised
|
# 1: Use Node PoL to set node to compromised
|
||||||
|
|
||||||
_id = str(uuid.uuid4())
|
_id = str(uuid.uuid4())
|
||||||
_start_step = randint(
|
_start_step = randint(2, max_step_compromised + 1) # step compromised
|
||||||
2, max_step_compromised + 1
|
pol_service_name = choice(list(node.services.keys()))
|
||||||
) # step compromised
|
|
||||||
pol_service_name = choice(
|
|
||||||
list(node.services.keys())
|
|
||||||
)
|
|
||||||
|
|
||||||
source_node_service = choice(
|
source_node_service = choice(list(source_node.services.values()))
|
||||||
list(source_node.services.values())
|
|
||||||
)
|
|
||||||
|
|
||||||
red_pol = NodeStateInstructionRed(
|
red_pol = NodeStateInstructionRed(
|
||||||
_id=_id,
|
_id=_id,
|
||||||
@@ -1299,7 +1292,7 @@ class Primaite(Env):
|
|||||||
_pol_state=SoftwareState.COMPROMISED,
|
_pol_state=SoftwareState.COMPROMISED,
|
||||||
_pol_source_node_id=source_node.node_id,
|
_pol_source_node_id=source_node.node_id,
|
||||||
_pol_source_node_service=source_node_service.name,
|
_pol_source_node_service=source_node_service.name,
|
||||||
_pol_source_node_service_state=source_node_service.software_state
|
_pol_source_node_service_state=source_node_service.software_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.red_node_pol[_id] = red_pol
|
self.red_node_pol[_id] = red_pol
|
||||||
@@ -1308,15 +1301,11 @@ class Primaite(Env):
|
|||||||
|
|
||||||
ier_id = str(uuid.uuid4())
|
ier_id = str(uuid.uuid4())
|
||||||
# Launch the attack after node is compromised, and not right at the end of the episode
|
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||||
ier_start_step = randint(
|
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
|
||||||
_start_step + 2, int(self.episode_steps * 0.8)
|
|
||||||
)
|
|
||||||
ier_end_step = self.episode_steps
|
ier_end_step = self.episode_steps
|
||||||
|
|
||||||
# Randomise the load, as a percentage of a random link bandwith
|
# Randomise the load, as a percentage of a random link bandwith
|
||||||
ier_load = uniform(0.4, 0.8) * choice(
|
ier_load = uniform(0.4, 0.8) * choice(bandwidths)
|
||||||
bandwidths
|
|
||||||
)
|
|
||||||
ier_protocol = pol_service_name # Same protocol as compromised node
|
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||||
ier_service = node.services[pol_service_name]
|
ier_service = node.services[pol_service_name]
|
||||||
ier_port = ier_service.port
|
ier_port = ier_service.port
|
||||||
@@ -1335,10 +1324,10 @@ class Primaite(Env):
|
|||||||
if len(possible_ier_destinations) < 1:
|
if len(possible_ier_destinations) < 1:
|
||||||
for server in servers:
|
for server in servers:
|
||||||
if not self.acl.is_blocked(
|
if not self.acl.is_blocked(
|
||||||
node.ip_address,
|
node.ip_address,
|
||||||
server.ip_address,
|
server.ip_address,
|
||||||
ier_service,
|
ier_service,
|
||||||
ier_port,
|
ier_port,
|
||||||
):
|
):
|
||||||
possible_ier_destinations.append(server.node_id)
|
possible_ier_destinations.append(server.node_id)
|
||||||
if len(possible_ier_destinations) < 1:
|
if len(possible_ier_destinations) < 1:
|
||||||
@@ -1376,6 +1365,6 @@ class Primaite(Env):
|
|||||||
_pol_state=SoftwareState.OVERWHELMED,
|
_pol_state=SoftwareState.OVERWHELMED,
|
||||||
_pol_source_node_id=source_node.node_id,
|
_pol_source_node_id=source_node.node_id,
|
||||||
_pol_source_node_service=source_node_service.name,
|
_pol_source_node_service=source_node_service.name,
|
||||||
_pol_source_node_service_state=source_node_service.software_state
|
_pol_source_node_service_state=source_node_service.software_state,
|
||||||
)
|
)
|
||||||
self.red_node_pol[o_pol_id] = o_red_pol
|
self.red_node_pol[o_pol_id] = o_red_pol
|
||||||
|
|||||||
@@ -153,4 +153,4 @@ class NodeStateInstructionRed(object):
|
|||||||
f"source_node_service={self.source_node_service}, "
|
f"source_node_service={self.source_node_service}, "
|
||||||
f"source_node_service_state={self.source_node_service_state}"
|
f"source_node_service_state={self.source_node_service_state}"
|
||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from primaite.config.lay_down_config import data_manipulation_config_path
|
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||||
from primaite.environment.primaite_env import Primaite
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||||
from tests import TEST_CONFIG_ROOT
|
from tests import TEST_CONFIG_ROOT
|
||||||
from tests.conftest import _get_temp_session_path
|
from tests.conftest import _get_temp_session_path
|
||||||
|
|
||||||
@@ -41,14 +42,14 @@ def test_random_red_agent_behaviour():
|
|||||||
|
|
||||||
# RUN TWICE so we can make sure that red agent is randomised
|
# RUN TWICE so we can make sure that red agent is randomised
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
|
|
||||||
"""Takes a config path and returns the created instance of Primaite."""
|
"""Takes a config path and returns the created instance of Primaite."""
|
||||||
session_timestamp: datetime = datetime.now()
|
session_timestamp: datetime = datetime.now()
|
||||||
session_path = _get_temp_session_path(session_timestamp)
|
session_path = _get_temp_session_path(session_timestamp)
|
||||||
|
|
||||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
env = Primaite(
|
env = Primaite(
|
||||||
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
training_config_path=TEST_CONFIG_ROOT
|
||||||
|
/ "one_node_states_on_off_main_config.yaml",
|
||||||
lay_down_config_path=data_manipulation_config_path(),
|
lay_down_config_path=data_manipulation_config_path(),
|
||||||
transaction_list=[],
|
transaction_list=[],
|
||||||
session_path=session_path,
|
session_path=session_path,
|
||||||
@@ -64,7 +65,7 @@ def test_random_red_agent_behaviour():
|
|||||||
# compare instructions to make sure that red instructions are truly random
|
# compare instructions to make sure that red instructions are truly random
|
||||||
for index, instruction in enumerate(list_of_node_instructions):
|
for index, instruction in enumerate(list_of_node_instructions):
|
||||||
for key in list_of_node_instructions[index].keys():
|
for key in list_of_node_instructions[index].keys():
|
||||||
instruction: NodeInstructionRed = list_of_node_instructions[index][key]
|
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
|
||||||
print(f"run {index}")
|
print(f"run {index}")
|
||||||
print(f"{key} start step: {instruction.get_start_step()}")
|
print(f"{key} start step: {instruction.get_start_step()}")
|
||||||
print(f"{key} end step: {instruction.get_end_step()}")
|
print(f"{key} end step: {instruction.get_end_step()}")
|
||||||
|
|||||||
Reference in New Issue
Block a user