feature\1522:

Create random red agent behaviour.
This commit is contained in:
Brian Kanyora
2023-06-22 15:34:13 +01:00
parent 9b0e24c27b
commit e0f3d61f65
5 changed files with 356 additions and 21 deletions

View File

@@ -1,7 +1,7 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
from typing import Any, Dict, Final, Optional, Union
import yaml
@@ -21,6 +21,9 @@ class TrainingConfig:
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
red_agent_identifier: str = "RANDOM"
"Creates Random Red Agent Attacks"
action_type: ActionType = ActionType.ANY
"The ActionType to use."
@@ -167,8 +170,7 @@ def main_training_config_path() -> Path:
return path
def load(file_path: Union[str, Path],
legacy_file: bool = False) -> TrainingConfig:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -213,9 +215,7 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -227,10 +227,7 @@ def convert_legacy_training_config_dict(
don't have action_type values.
:return: The converted training config dict.
"""
config_dict = {
"num_steps": num_steps,
"action_type": action_type
}
config_dict = {"num_steps": num_steps, "action_type": action_type}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:

View File

@@ -14,8 +14,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, \
is_valid_node_action
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -24,8 +23,9 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState, ObservationType,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config import training_config
@@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import \
NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
apply_red_agent_node_pol
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
@@ -177,7 +175,6 @@ class Primaite(Env):
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
with open(self._lay_down_config_path, "r") as file:
# Open the config file and build the environment laydown
@@ -238,7 +235,9 @@ class Primaite(Env):
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
_LOGGER.info(
f"Invalid action type selected: {self.training_config.action_type}"
)
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -275,6 +274,10 @@ class Primaite(Env):
# Does this for both live and reference nodes
self.reset_environment()
# Create a random red agent to use for this episode
if self.training_config.red_agent_identifier == "RANDOM":
self.create_random_red_agent()
# Reset counters and totals
self.total_reward = 0
self.step_count = 0
@@ -379,7 +382,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
#print(f" Step {self.step_count} Reward: {str(reward)}")
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -1033,7 +1036,6 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""
Extracts action_info.
@@ -1216,3 +1218,152 @@ class Primaite(Env):
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict
def create_random_red_agent(self):
    """
    Build a randomised red agent for the coming episode.

    Called from ``env.reset()`` when ``red_agent_identifier`` is ``RANDOM``.
    Populates ``self.red_node_pol`` (instructions that compromise computers
    and later allow the attacked servers to become overwhelmed) and
    ``self.red_iers`` (the attack traffic launched from each compromised
    computer).
    """
    # Discard the red behaviour generated for the previous episode.
    self.red_iers = {}
    self.red_node_pol = {}

    node_list = list(self.nodes.values())
    # Only computers can become compromised; servers are the attack targets.
    computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
    servers = [node for node in node_list if node.node_type == NodeType.SERVER]

    # Randomly select between 1 and all computers, without choosing the
    # same node twice (replace=False).
    num_nodes_to_compromise = np.random.randint(1, len(computers) + 1)
    nodes_to_be_compromised = np.random.choice(
        computers, num_nodes_to_compromise, replace=False
    )

    # Always compromise in the first half of the episode so the attack has
    # time to play out.
    max_step_compromised = self.episode_steps // 2

    # Link bandwidths, used to randomise the load of each attack IER.
    bandwidths = [link.get_bandwidth() for link in self.links.values()]

    for n, node in enumerate(nodes_to_be_compromised):
        # 1: Node PoL that sets the computer's service to COMPROMISED.
        _id = str(1000 + n)  # offset avoids clashing with laydown PoL ids
        _start_step = np.random.randint(2, max_step_compromised + 1)
        _end_step = _start_step  # become compromised on a single step
        _target_node_id = node.node_id
        _pol_initiator = "DIRECT"
        _pol_type = NodePOLType["SERVICE"]  # all computers are service nodes
        # Random service on the node; its protocol is reused by the attack IER.
        pol_service_name = np.random.choice(list(node.get_services().keys()))
        _pol_state = SoftwareState.COMPROMISED
        # DIRECT instructions are not triggered by another node, so the
        # source-node fields are placeholders.
        # TODO(review): confirm "NONE" matches the laydown config convention.
        _pol_source_node_id = "NONE"
        _pol_source_node_service = "NONE"
        _pol_source_node_service_state = "NONE"
        red_pol = NodeStateInstructionRed(
            _id,
            _start_step,
            _end_step,
            _target_node_id,
            _pol_initiator,
            _pol_type,
            pol_service_name,
            _pol_state,
            _pol_source_node_id,
            _pol_source_node_service,
            _pol_source_node_service_state,
        )
        self.red_node_pol[_id] = red_pol

        # 2: Launch the attack from the compromised node - set the IER.
        ier_id = str(2000 + n)
        # Launch after the node is compromised and (where the episode is
        # long enough) before the final 20% of the episode. max() keeps the
        # randint bounds valid for very short episodes.
        ier_start_step = np.random.randint(
            _start_step + 2, max(_start_step + 3, int(self.episode_steps * 0.8))
        )
        ier_end_step = self.episode_steps
        ier_source_node_id = node.get_id()
        # Randomise the load as 40-80% of a randomly chosen link bandwidth.
        ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice(bandwidths)
        ier_protocol = pol_service_name  # same protocol as compromised node
        ier_service = node.get_services()[pol_service_name]  # same service as the PoL
        ier_port = ier_service.get_port()
        ier_mission_criticality = 0  # red IERs never matter to green success

        # We choose a node to attack based on the first that applies:
        # a. Green IERs, select dest node of the red IER based on dest node of green IER
        # b. Attack a random server that doesn't have a DENY acl rule in default config
        # c. Attack a random server
        possible_ier_destinations = [
            green_ier.get_dest_node_id()
            for green_ier in self.green_iers.values()
            if green_ier.get_source_node_id() == node.get_id()
        ]
        if not possible_ier_destinations:
            possible_ier_destinations = [
                server.node_id
                for server in servers
                if not self.acl.is_blocked(
                    node.get_ip_address(), server.ip_address, ier_service, ier_port
                )
            ]
        if not possible_ier_destinations:
            # If still none found choose from all servers.
            possible_ier_destinations = [server.node_id for server in servers]
        ier_dest = np.random.choice(possible_ier_destinations)
        self.red_iers[ier_id] = IER(
            ier_id,
            ier_start_step,
            ier_end_step,
            ier_load,
            ier_protocol,
            ier_port,
            ier_source_node_id,
            ier_dest,
            ier_mission_criticality,
        )

        # 3: Make sure the targeted node can be set to OVERWHELMED - with node PoL.
        # TODO remove duplicate red PoL for the same targeted service - must
        # take the start step into account.
        o_pol_id = str(3000 + n)
        # Can become overwhelmed from the step the attack is launched...
        o_pol_start_step = ier_start_step
        # ...until the end of the episode.
        o_pol_end_step = self.episode_steps
        o_pol_node_id = ier_dest  # node affected is the one targeted by the IER
        o_pol_node_type = NodePOLType["SERVICE"]  # always targets service nodes
        o_pol_service_name = ier_protocol  # same protocol/service as the attack IER
        o_pol_new_state = SoftwareState["OVERWHELMED"]
        o_red_pol = NodeStateInstructionRed(
            o_pol_id,
            o_pol_start_step,
            o_pol_end_step,
            o_pol_node_id,
            # NOTE(review): presumably this should be IER-initiated rather
            # than DIRECT - confirm against NodePOLInitiator semantics.
            _pol_initiator,
            o_pol_node_type,
            o_pol_service_name,
            o_pol_new_state,
            _pol_source_node_id,
            _pol_source_node_service,
            _pol_source_node_service_state,
        )
        self.red_node_pol[o_pol_id] = o_red_pol

View File

@@ -137,3 +137,20 @@ class NodeStateInstructionRed(object):
The source node service state
"""
return self.source_node_service_state
def __repr__(self):
    """Return a debug representation listing every instruction field."""
    field_names = (
        "id",
        "start_step",
        "end_step",
        "target_node_id",
        "initiator",
        "pol_type",
        "service_name",
        "state",
        "source_node_id",
        "source_node_service",
        "source_node_service_state",
    )
    rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
    return f"{self.__class__.__name__}({rendered})"

View File

@@ -0,0 +1,96 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: GENERIC
#
red_agent_identifier: RANDOM
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: ANY
# Number of episodes to run per session
num_episodes: 1
# Number of time_steps per episode
num_steps: 5
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1_000_000_000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,74 @@
from datetime import time, datetime
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path
def run_generic(env, config_values):
    """
    Run a generic (non-learning) agent against the environment.

    :param env: The Primaite gym environment to step through.
    :param config_values: Training config providing ``num_episodes``,
        ``num_steps`` and ``time_delay`` (delay between steps, in ms).
    """
    # Bug fix: the module-level ``time`` is ``datetime.time`` (imported via
    # ``from datetime import time, datetime``), which has no ``sleep``.
    # Bind the stdlib module locally so the per-step delay works.
    import time

    # Reset the environment at the start of the episode
    env.reset()
    for _episode in range(config_values.num_episodes):
        for _step in range(config_values.num_steps):
            # Send the observation space to the agent to get an action
            # TEMP - random action for now
            # action = env.blue_agent_action(obs)
            # action = env.action_space.sample()
            action = 0
            # Run the simulation step on the live environment
            obs, reward, done, info = env.step(action)
            # Break if done is True
            if done:
                break
            # Introduce a delay between steps
            time.sleep(config_values.time_delay / 1000)
        # Reset the environment at the end of the episode
        env.reset()
    env.close()
def test_random_red_agent_behaviour():
    """
    Test that the RANDOM red agent is re-randomised between sessions.

    Two environments built from identical configs should produce different
    red node PoL instruction sets.
    """
    list_of_node_instructions = []
    for _ in range(2):
        # Takes a config path and returns the created instance of Primaite.
        session_timestamp: datetime = datetime.now()
        session_path = _get_temp_session_path(session_timestamp)
        timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
        env = Primaite(
            training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
            lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
            transaction_list=[],
            session_path=session_path,
            timestamp_str=timestamp_str,
        )
        training_config = env.training_config
        training_config.num_steps = env.episode_steps
        # TODO: This needs to be refactored to happen outside. Should be
        # part of a main Session class.
        if training_config.agent_identifier == "GENERIC":
            run_generic(env, training_config)
        list_of_node_instructions.append(env.red_node_pol)
    # NOTE(review): NodeStateInstructionRed defines no __eq__, so dict
    # inequality here falls back to element identity and is a weak check -
    # comparing the instructions' field values would be stronger.
    assert list_of_node_instructions[0] != list_of_node_instructions[1]