Merge remote-tracking branch 'origin/dev' into feature/1558-flatten-spaces
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -137,3 +137,5 @@ dmypy.json
|
|||||||
|
|
||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
|
.idea/
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ The environment config file consists of the following attributes:
|
|||||||
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
|
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
|
||||||
* STABLE_BASELINES3_A2C - use a SB3 A2C agent
|
* STABLE_BASELINES3_A2C - use a SB3 A2C agent
|
||||||
|
|
||||||
|
* **random_red_agent** [bool]
|
||||||
|
|
||||||
|
Determines if the session should be run with a random red agent
|
||||||
|
|
||||||
* **action_type** [enum]
|
* **action_type** [enum]
|
||||||
|
|
||||||
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
|
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
|
||||||
|
|||||||
@@ -5,7 +5,12 @@
|
|||||||
# "STABLE_BASELINES3_PPO"
|
# "STABLE_BASELINES3_PPO"
|
||||||
# "STABLE_BASELINES3_A2C"
|
# "STABLE_BASELINES3_A2C"
|
||||||
# "GENERIC"
|
# "GENERIC"
|
||||||
agent_identifier: STABLE_BASELINES3_PPO
|
agent_identifier: STABLE_BASELINES3_A2C
|
||||||
|
|
||||||
|
# RED AGENT IDENTIFIER
|
||||||
|
# RANDOM or NONE
|
||||||
|
random_red_agent: False
|
||||||
|
|
||||||
# Sets How the Action Space is defined:
|
# Sets How the Action Space is defined:
|
||||||
# "NODE"
|
# "NODE"
|
||||||
# "ACL"
|
# "ACL"
|
||||||
|
|||||||
@@ -0,0 +1,99 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agent_identifier: STABLE_BASELINES3_A2C
|
||||||
|
|
||||||
|
# RED AGENT IDENTIFIER
|
||||||
|
# RANDOM or NONE
|
||||||
|
random_red_agent: True
|
||||||
|
|
||||||
|
# Sets How the Action Space is defined:
|
||||||
|
# "NODE"
|
||||||
|
# "ACL"
|
||||||
|
# "ANY" node and acl actions
|
||||||
|
action_type: NODE
|
||||||
|
# Number of episodes to run per session
|
||||||
|
num_episodes: 10
|
||||||
|
# Number of time_steps per episode
|
||||||
|
num_steps: 256
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
time_delay: 10
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
session_type: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
load_agent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observation_space_high_value: 1000000000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
all_ok: 0
|
||||||
|
# Node Hardware State
|
||||||
|
off_should_be_on: -10
|
||||||
|
off_should_be_resetting: -5
|
||||||
|
on_should_be_off: -2
|
||||||
|
on_should_be_resetting: -5
|
||||||
|
resetting_should_be_on: -5
|
||||||
|
resetting_should_be_off: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
good_should_be_patching: 2
|
||||||
|
good_should_be_compromised: 5
|
||||||
|
good_should_be_overwhelmed: 5
|
||||||
|
patching_should_be_good: -5
|
||||||
|
patching_should_be_compromised: 2
|
||||||
|
patching_should_be_overwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromised_should_be_good: -20
|
||||||
|
compromised_should_be_patching: -20
|
||||||
|
compromised_should_be_overwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmed_should_be_good: -20
|
||||||
|
overwhelmed_should_be_patching: -20
|
||||||
|
overwhelmed_should_be_compromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
good_should_be_repairing: 2
|
||||||
|
good_should_be_restoring: 2
|
||||||
|
good_should_be_corrupt: 5
|
||||||
|
good_should_be_destroyed: 10
|
||||||
|
repairing_should_be_good: -5
|
||||||
|
repairing_should_be_restoring: 2
|
||||||
|
repairing_should_be_corrupt: 2
|
||||||
|
repairing_should_be_destroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoring_should_be_good: -10
|
||||||
|
restoring_should_be_repairing: -2
|
||||||
|
restoring_should_be_corrupt: 1
|
||||||
|
restoring_should_be_destroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corrupt_should_be_good: -10
|
||||||
|
corrupt_should_be_repairing: -10
|
||||||
|
corrupt_should_be_restoring: -10
|
||||||
|
corrupt_should_be_destroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyed_should_be_good: -20
|
||||||
|
destroyed_should_be_repairing: -20
|
||||||
|
destroyed_should_be_restoring: -20
|
||||||
|
destroyed_should_be_corrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
red_ier_running: -5
|
||||||
|
green_ier_blocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
os_patching_duration: 5 # The time taken to patch the OS
|
||||||
|
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||||
|
service_patching_duration: 5 # The time taken to patch a service
|
||||||
|
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||||
|
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||||
|
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||||
@@ -21,6 +21,9 @@ class TrainingConfig:
|
|||||||
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
||||||
"The Red Agent algo/class to be used."
|
"The Red Agent algo/class to be used."
|
||||||
|
|
||||||
|
random_red_agent: bool = False
|
||||||
|
"Creates Random Red Agent Attacks"
|
||||||
|
|
||||||
action_type: ActionType = ActionType.ANY
|
action_type: ActionType = ActionType.ANY
|
||||||
"The ActionType to use."
|
"The ActionType to use."
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
import copy
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
import uuid as uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from random import choice, randint, sample, uniform
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@@ -274,6 +276,10 @@ class Primaite(Env):
|
|||||||
# Does this for both live and reference nodes
|
# Does this for both live and reference nodes
|
||||||
self.reset_environment()
|
self.reset_environment()
|
||||||
|
|
||||||
|
# Create a random red agent to use for this episode
|
||||||
|
if self.training_config.random_red_agent:
|
||||||
|
self._create_random_red_agent()
|
||||||
|
|
||||||
# Reset counters and totals
|
# Reset counters and totals
|
||||||
self.total_reward = 0
|
self.total_reward = 0
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
@@ -1228,3 +1234,136 @@ class Primaite(Env):
|
|||||||
# Combine the Node dict and ACL dict
|
# Combine the Node dict and ACL dict
|
||||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||||
return combined_action_dict
|
return combined_action_dict
|
||||||
|
|
||||||
|
def _create_random_red_agent(self):
|
||||||
|
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||||
|
# Reset the current red iers and red node pol
|
||||||
|
self.red_iers = {}
|
||||||
|
self.red_node_pol = {}
|
||||||
|
|
||||||
|
# Decide how many nodes become compromised
|
||||||
|
node_list = list(self.nodes.values())
|
||||||
|
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
|
||||||
|
max_num_nodes_compromised = len(
|
||||||
|
computers
|
||||||
|
) # only computers can become compromised
|
||||||
|
# random select between 1 and max_num_nodes_compromised
|
||||||
|
num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
|
||||||
|
|
||||||
|
# Decide which of the nodes to compromise
|
||||||
|
nodes_to_be_compromised = sample(computers, num_nodes_to_compromise)
|
||||||
|
|
||||||
|
# choose a random compromise node to be source of attacks
|
||||||
|
source_node = choice(nodes_to_be_compromised)
|
||||||
|
|
||||||
|
# For each of the nodes to be compromised decide which step they become compromised
|
||||||
|
max_step_compromised = (
|
||||||
|
self.episode_steps // 2
|
||||||
|
) # always compromise in first half of episode
|
||||||
|
|
||||||
|
# Bandwidth for all links
|
||||||
|
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
|
||||||
|
|
||||||
|
if len(bandwidths) < 1:
|
||||||
|
msg = "Random red agent cannot be used on a network without any links"
|
||||||
|
_LOGGER.error(msg)
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
|
||||||
|
|
||||||
|
for n, node in enumerate(nodes_to_be_compromised):
|
||||||
|
# 1: Use Node PoL to set node to compromised
|
||||||
|
|
||||||
|
_id = str(uuid.uuid4())
|
||||||
|
_start_step = randint(2, max_step_compromised + 1) # step compromised
|
||||||
|
pol_service_name = choice(list(node.services.keys()))
|
||||||
|
|
||||||
|
source_node_service = choice(list(source_node.services.values()))
|
||||||
|
|
||||||
|
red_pol = NodeStateInstructionRed(
|
||||||
|
_id=_id,
|
||||||
|
_start_step=_start_step,
|
||||||
|
_end_step=_start_step, # only run for 1 step
|
||||||
|
_target_node_id=node.node_id,
|
||||||
|
_pol_initiator="DIRECT",
|
||||||
|
_pol_type=NodePOLType["SERVICE"],
|
||||||
|
pol_protocol=pol_service_name,
|
||||||
|
_pol_state=SoftwareState.COMPROMISED,
|
||||||
|
_pol_source_node_id=source_node.node_id,
|
||||||
|
_pol_source_node_service=source_node_service.name,
|
||||||
|
_pol_source_node_service_state=source_node_service.software_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.red_node_pol[_id] = red_pol
|
||||||
|
|
||||||
|
# 2: Launch the attack from compromised node - set the IER
|
||||||
|
|
||||||
|
ier_id = str(uuid.uuid4())
|
||||||
|
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||||
|
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
|
||||||
|
ier_end_step = self.episode_steps
|
||||||
|
|
||||||
|
# Randomise the load, as a percentage of a random link bandwith
|
||||||
|
ier_load = uniform(0.4, 0.8) * choice(bandwidths)
|
||||||
|
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||||
|
ier_service = node.services[pol_service_name]
|
||||||
|
ier_port = ier_service.port
|
||||||
|
ier_mission_criticality = (
|
||||||
|
0 # Red IER will never be important to green agent success
|
||||||
|
)
|
||||||
|
# We choose a node to attack based on the first that applies:
|
||||||
|
# a. Green IERs, select dest node of the red ier based on dest node of green IER
|
||||||
|
# b. Attack a random server that doesn't have a DENY acl rule in default config
|
||||||
|
# c. Attack a random server
|
||||||
|
possible_ier_destinations = [
|
||||||
|
ier.get_dest_node_id()
|
||||||
|
for ier in list(self.green_iers.values())
|
||||||
|
if ier.get_source_node_id() == node.node_id
|
||||||
|
]
|
||||||
|
if len(possible_ier_destinations) < 1:
|
||||||
|
for server in servers:
|
||||||
|
if not self.acl.is_blocked(
|
||||||
|
node.ip_address,
|
||||||
|
server.ip_address,
|
||||||
|
ier_service,
|
||||||
|
ier_port,
|
||||||
|
):
|
||||||
|
possible_ier_destinations.append(server.node_id)
|
||||||
|
if len(possible_ier_destinations) < 1:
|
||||||
|
# If still none found choose from all servers
|
||||||
|
possible_ier_destinations = [server.node_id for server in servers]
|
||||||
|
ier_dest = choice(possible_ier_destinations)
|
||||||
|
self.red_iers[ier_id] = IER(
|
||||||
|
ier_id,
|
||||||
|
ier_start_step,
|
||||||
|
ier_end_step,
|
||||||
|
ier_load,
|
||||||
|
ier_protocol,
|
||||||
|
ier_port,
|
||||||
|
node.node_id,
|
||||||
|
ier_dest,
|
||||||
|
ier_mission_criticality,
|
||||||
|
)
|
||||||
|
|
||||||
|
overwhelm_pol = red_pol
|
||||||
|
overwhelm_pol.id = str(uuid.uuid4())
|
||||||
|
overwhelm_pol.end_step = self.episode_steps
|
||||||
|
|
||||||
|
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
|
||||||
|
# # TODO remove duplicate red pol for same targetted service - must take into account start step
|
||||||
|
|
||||||
|
o_pol_id = str(uuid.uuid4())
|
||||||
|
o_red_pol = NodeStateInstructionRed(
|
||||||
|
_id=o_pol_id,
|
||||||
|
_start_step=ier_start_step,
|
||||||
|
_end_step=self.episode_steps,
|
||||||
|
_target_node_id=ier_dest,
|
||||||
|
_pol_initiator="DIRECT",
|
||||||
|
_pol_type=NodePOLType["SERVICE"],
|
||||||
|
pol_protocol=ier_protocol,
|
||||||
|
_pol_state=SoftwareState.OVERWHELMED,
|
||||||
|
_pol_source_node_id=source_node.node_id,
|
||||||
|
_pol_source_node_service=source_node_service.name,
|
||||||
|
_pol_source_node_service_state=source_node_service.software_state,
|
||||||
|
)
|
||||||
|
self.red_node_pol[o_pol_id] = o_red_pol
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
"""Defines node behaviour for Green PoL."""
|
"""Defines node behaviour for Green PoL."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from primaite.common.enums import NodePOLType
|
from primaite.common.enums import NodePOLType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
class NodeStateInstructionRed(object):
|
class NodeStateInstructionRed(object):
|
||||||
"""The Node State Instruction class."""
|
"""The Node State Instruction class."""
|
||||||
|
|
||||||
|
|||||||
77
tests/test_red_random_agent_behaviour.py
Normal file
77
tests/test_red_random_agent_behaviour.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from primaite.config.lay_down_config import data_manipulation_config_path
|
||||||
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||||
|
from tests import TEST_CONFIG_ROOT
|
||||||
|
from tests.conftest import _get_temp_session_path
|
||||||
|
|
||||||
|
|
||||||
|
def run_generic(env, config_values):
|
||||||
|
"""Run against a generic agent."""
|
||||||
|
# Reset the environment at the start of the episode
|
||||||
|
env.reset()
|
||||||
|
for episode in range(0, config_values.num_episodes):
|
||||||
|
for step in range(0, config_values.num_steps):
|
||||||
|
# Send the observation space to the agent to get an action
|
||||||
|
# TEMP - random action for now
|
||||||
|
# action = env.blue_agent_action(obs)
|
||||||
|
# action = env.action_space.sample()
|
||||||
|
action = 0
|
||||||
|
|
||||||
|
# Run the simulation step on the live environment
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
# Break if done is True
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset the environment at the end of the episode
|
||||||
|
env.reset()
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_red_agent_behaviour():
|
||||||
|
"""
|
||||||
|
Test that hardware state is penalised at each step.
|
||||||
|
|
||||||
|
When the initial state is OFF compared to reference state which is ON.
|
||||||
|
"""
|
||||||
|
list_of_node_instructions = []
|
||||||
|
|
||||||
|
# RUN TWICE so we can make sure that red agent is randomised
|
||||||
|
for i in range(2):
|
||||||
|
"""Takes a config path and returns the created instance of Primaite."""
|
||||||
|
session_timestamp: datetime = datetime.now()
|
||||||
|
session_path = _get_temp_session_path(session_timestamp)
|
||||||
|
|
||||||
|
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
env = Primaite(
|
||||||
|
training_config_path=TEST_CONFIG_ROOT
|
||||||
|
/ "one_node_states_on_off_main_config.yaml",
|
||||||
|
lay_down_config_path=data_manipulation_config_path(),
|
||||||
|
transaction_list=[],
|
||||||
|
session_path=session_path,
|
||||||
|
timestamp_str=timestamp_str,
|
||||||
|
)
|
||||||
|
# set red_agent_
|
||||||
|
env.training_config.random_red_agent = True
|
||||||
|
training_config = env.training_config
|
||||||
|
training_config.num_steps = env.episode_steps
|
||||||
|
|
||||||
|
run_generic(env, training_config)
|
||||||
|
# add red pol instructions to list
|
||||||
|
list_of_node_instructions.append(env.red_node_pol)
|
||||||
|
|
||||||
|
# compare instructions to make sure that red instructions are truly random
|
||||||
|
for index, instruction in enumerate(list_of_node_instructions):
|
||||||
|
for key in list_of_node_instructions[index].keys():
|
||||||
|
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
|
||||||
|
print(f"run {index}")
|
||||||
|
print(f"{key} start step: {instruction.get_start_step()}")
|
||||||
|
print(f"{key} end step: {instruction.get_end_step()}")
|
||||||
|
print(f"{key} target node id: {instruction.get_target_node_id()}")
|
||||||
|
print("")
|
||||||
|
|
||||||
|
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
|
||||||
Reference in New Issue
Block a user