feature\1522:
Create random red agent behaviour.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union, Optional
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -21,6 +21,9 @@ class TrainingConfig:
|
||||
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
||||
"The Red Agent algo/class to be used."
|
||||
|
||||
red_agent_identifier: str = "RANDOM"
|
||||
"Creates Random Red Agent Attacks"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use."
|
||||
|
||||
@@ -167,8 +170,7 @@ def main_training_config_path() -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path],
|
||||
legacy_file: bool = False) -> TrainingConfig:
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
@@ -213,9 +215,7 @@ def load(file_path: Union[str, Path],
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
num_steps: int = 256,
|
||||
action_type: str = "ANY"
|
||||
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
@@ -227,10 +227,7 @@ def convert_legacy_training_config_dict(
|
||||
don't have action_type values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"num_steps": num_steps,
|
||||
"action_type": action_type
|
||||
}
|
||||
config_dict = {"num_steps": num_steps, "action_type": action_type}
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
|
||||
@@ -14,8 +14,7 @@ from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
||||
is_valid_node_action
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
@@ -24,8 +23,9 @@ from primaite.common.enums import (
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
ObservationType,
|
||||
Priority,
|
||||
SoftwareState, ObservationType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
@@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import \
|
||||
NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
||||
apply_red_agent_node_pol
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -177,7 +175,6 @@ class Primaite(Env):
|
||||
# It will be initialised later.
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -238,7 +235,9 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
_LOGGER.info(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
# Set up a csv to store the results of the training
|
||||
try:
|
||||
header = ["Episode", "Average Reward"]
|
||||
@@ -275,6 +274,10 @@ class Primaite(Env):
|
||||
# Does this for both live and reference nodes
|
||||
self.reset_environment()
|
||||
|
||||
# Create a random red agent to use for this episode
|
||||
if self.training_config.red_agent_identifier == "RANDOM":
|
||||
self.create_random_red_agent()
|
||||
|
||||
# Reset counters and totals
|
||||
self.total_reward = 0
|
||||
self.step_count = 0
|
||||
@@ -379,7 +382,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -1033,7 +1036,6 @@ class Primaite(Env):
|
||||
"""
|
||||
self.observation_type = ObservationType[observation_info["type"]]
|
||||
|
||||
|
||||
def get_action_info(self, action_info):
|
||||
"""
|
||||
Extracts action_info.
|
||||
@@ -1216,3 +1218,152 @@ class Primaite(Env):
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
return combined_action_dict
|
||||
|
||||
def create_random_red_agent(self):
|
||||
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||
|
||||
# Reset the current red iers and red node pol
|
||||
self.red_iers = {}
|
||||
self.red_node_pol = {}
|
||||
|
||||
# Decide how many nodes become compromised
|
||||
node_list = list(self.nodes.values())
|
||||
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
|
||||
max_num_nodes_compromised = len(
|
||||
computers
|
||||
) # only computers can become compromised
|
||||
# random select between 1 and max_num_nodes_compromised
|
||||
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1)
|
||||
|
||||
# Decide which of the nodes to compromise
|
||||
nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise)
|
||||
|
||||
# For each of the nodes to be compromised decide which step they become compromised
|
||||
max_step_compromised = (
|
||||
self.episode_steps // 2
|
||||
) # always compromise in first half of episode
|
||||
|
||||
# Bandwidth for all links
|
||||
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
|
||||
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
|
||||
|
||||
for n, node in enumerate(nodes_to_be_compromised):
|
||||
# 1: Use Node PoL to set node to compromised
|
||||
|
||||
_id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate
|
||||
_start_step = np.random.randint(
|
||||
2, max_step_compromised + 1
|
||||
) # step compromised
|
||||
_end_step = _start_step # Become compromised on 1 step
|
||||
_target_node_id = node.node_id
|
||||
_pol_initiator = "DIRECT"
|
||||
_pol_type = NodePOLType["SERVICE"] # All computers are service nodes
|
||||
pol_service_name = np.random.choice(
|
||||
list(node.get_services().keys())
|
||||
) # Random service may wish to change this, currently always TCP)
|
||||
pol_protocol = pol_protocol
|
||||
_pol_state = SoftwareState.COMPROMISED
|
||||
is_entry_node = True # Assumes all computers in network are entry nodes
|
||||
_pol_source_node_id = _pol_source_node_id
|
||||
_pol_source_node_service = _pol_source_node_service
|
||||
_pol_source_node_service_state = _pol_source_node_service_state
|
||||
red_pol = NodeStateInstructionRed(
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
)
|
||||
|
||||
self.red_node_pol[_id] = red_pol
|
||||
|
||||
# 2: Launch the attack from compromised node - set the IER
|
||||
|
||||
ier_id = str(2000 + n)
|
||||
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||
ier_start_step = np.random.randint(
|
||||
_start_step + 2, int(self.episode_steps * 0.8)
|
||||
)
|
||||
ier_end_step = self.episode_steps
|
||||
ier_source_node_id = node.get_id()
|
||||
# Randomise the load, as a percentage of a random link bandwith
|
||||
ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice(
|
||||
bandwidths
|
||||
)
|
||||
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||
ier_service = node.get_services()[
|
||||
pol_service_name
|
||||
] # same service as defined in the pol
|
||||
ier_port = ier_service.get_port()
|
||||
ier_mission_criticality = (
|
||||
0 # Red IER will never be important to green agent success
|
||||
)
|
||||
# We choose a node to attack based on the first that applies:
|
||||
# a. Green IERs, select dest node of the red ier based on dest node of green IER
|
||||
# b. Attack a random server that doesn't have a DENY acl rule in default config
|
||||
# c. Attack a random server
|
||||
possible_ier_destinations = [
|
||||
ier.get_dest_node_id()
|
||||
for ier in list(self.green_iers.values())
|
||||
if ier.get_source_node_id() == node.get_id()
|
||||
]
|
||||
if len(possible_ier_destinations) < 1:
|
||||
for server in servers:
|
||||
if not self.acl.is_blocked(
|
||||
node.get_ip_address(),
|
||||
server.ip_address,
|
||||
ier_service,
|
||||
ier_port,
|
||||
):
|
||||
possible_ier_destinations.append(server.node_id)
|
||||
if len(possible_ier_destinations) < 1:
|
||||
# If still none found choose from all servers
|
||||
possible_ier_destinations = [server.node_id for server in servers]
|
||||
ier_dest = np.random.choice(possible_ier_destinations)
|
||||
self.red_iers[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source_node_id,
|
||||
ier_dest,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
|
||||
# TODO remove duplicate red pol for same targetted service - must take into account start step
|
||||
|
||||
o_pol_id = str(3000 + n)
|
||||
o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
|
||||
o_pol_end_step = (
|
||||
self.episode_steps
|
||||
) # Can become compromised at any timestep after start
|
||||
o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
|
||||
o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
|
||||
o_pol_service_name = (
|
||||
ier_protocol # Same protocol/service as the IER uses to attack
|
||||
)
|
||||
o_pol_new_state = SoftwareState["OVERWHELMED"]
|
||||
o_pol_entry_node = False # Assumes servers are not entry nodes
|
||||
o_red_pol = NodeStateInstructionRed(
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
)
|
||||
self.red_node_pol[o_pol_id] = o_red_pol
|
||||
|
||||
@@ -137,3 +137,20 @@ class NodeStateInstructionRed(object):
|
||||
The source node service state
|
||||
"""
|
||||
return self.source_node_service_state
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"id={self.id}, "
|
||||
f"start_step={self.start_step}, "
|
||||
f"end_step={self.end_step}, "
|
||||
f"target_node_id={self.target_node_id}, "
|
||||
f"initiator={self.initiator}, "
|
||||
f"pol_type={self.pol_type}, "
|
||||
f"service_name={self.service_name}, "
|
||||
f"state={self.state}, "
|
||||
f"source_node_id={self.source_node_id}, "
|
||||
f"source_node_service={self.source_node_service}, "
|
||||
f"source_node_service_state={self.source_node_service_state}"
|
||||
f")"
|
||||
)
|
||||
96
tests/config/random_agent_main_config.yaml
Normal file
96
tests/config/random_agent_main_config.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: GENERIC
|
||||
#
|
||||
red_agent_identifier: RANDOM
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 1
|
||||
# Number of time_steps per episode
|
||||
num_steps: 5
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 1
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
74
tests/test_red_random_agent_behaviour.py
Normal file
74
tests/test_red_random_agent_behaviour.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import time, datetime
|
||||
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_temp_session_path
|
||||
|
||||
|
||||
def run_generic(env, config_values):
|
||||
"""Run against a generic agent."""
|
||||
# Reset the environment at the start of the episode
|
||||
env.reset()
|
||||
for episode in range(0, config_values.num_episodes):
|
||||
for step in range(0, config_values.num_steps):
|
||||
# Send the observation space to the agent to get an action
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
# action = env.action_space.sample()
|
||||
action = 0
|
||||
|
||||
# Run the simulation step on the live environment
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
# Break if done is True
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
# Reset the environment at the end of the episode
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def test_random_red_agent_behaviour():
|
||||
"""
|
||||
Test that hardware state is penalised at each step.
|
||||
|
||||
When the initial state is OFF compared to reference state which is ON.
|
||||
"""
|
||||
list_of_node_instructions = []
|
||||
for i in range(2):
|
||||
|
||||
"""Takes a config path and returns the created instance of Primaite."""
|
||||
session_timestamp: datetime = datetime.now()
|
||||
session_path = _get_temp_session_path(session_timestamp)
|
||||
|
||||
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
env = Primaite(
|
||||
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
|
||||
transaction_list=[],
|
||||
session_path=session_path,
|
||||
timestamp_str=timestamp_str,
|
||||
)
|
||||
training_config = env.training_config
|
||||
training_config.num_steps = env.episode_steps
|
||||
|
||||
# TOOD: This needs t be refactored to happen outside. Should be part of
|
||||
# a main Session class.
|
||||
if training_config.agent_identifier == "GENERIC":
|
||||
run_generic(env, training_config)
|
||||
all_red_actions = env.red_node_pol
|
||||
list_of_node_instructions.append(all_red_actions)
|
||||
|
||||
# assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1]))
|
||||
print(list_of_node_instructions[0]["1"].get_start_step())
|
||||
print(list_of_node_instructions[0]["1"].get_end_step())
|
||||
print(list_of_node_instructions[0]["1"].get_target_node_id())
|
||||
print(list_of_node_instructions[1]["1"].get_start_step())
|
||||
print(list_of_node_instructions[1]["1"].get_end_step())
|
||||
print(list_of_node_instructions[1]["1"].get_target_node_id())
|
||||
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
|
||||
Reference in New Issue
Block a user