#917 - Synced with dev (at the point of random red agent)

This commit is contained in:
Chris McCarthy
2023-07-03 17:25:21 +01:00
parent 72b0050b1b
commit c36ddfa03f
4 changed files with 172 additions and 84 deletions

View File

@@ -7,8 +7,10 @@
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# RED AGENT IDENTIFIER
# RANDOM or NONE
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
# Sets How the Action Space is defined:

View File

@@ -427,7 +427,12 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
_LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
print(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
def interpret_action_and_apply(self, _action):
"""
@@ -437,16 +442,21 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
_LOGGER.error("Invalid action type found")
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""
@@ -510,8 +520,7 @@ class Primaite(Env):
elif property_action == 1:
# Patch (valid action if it's good or compromised)
node.set_service_state(
self.services_list[service_index],
SoftwareState.PATCHING,
self.services_list[service_index], SoftwareState.PATCHING
)
else:
# Node is not of Service Type
@@ -709,7 +718,8 @@ class Primaite(Env):
_LOGGER.error(f"Invalid item_type: {item_type}")
pass
_LOGGER.debug("Environment configuration loaded")
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
def create_node(self, item):
"""
@@ -1166,12 +1176,7 @@ class Primaite(Env):
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [
node,
node_property,
node_action,
service_state,
]
action = [node, node_property, node_action, service_state]
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
actions[action_key] = action
@@ -1221,7 +1226,11 @@ class Primaite(Env):
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
@@ -1235,7 +1244,8 @@ class Primaite(Env):
# Decide how many nodes become compromised
node_list = list(self.nodes.values())
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
computers = [node for node in node_list if
node.node_type == NodeType.COMPUTER]
max_num_nodes_compromised = len(
computers
) # only computers can become compromised
@@ -1250,7 +1260,7 @@ class Primaite(Env):
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
self.episode_steps // 2
) # always compromise in first half of episode
# Bandwidth for all links
@@ -1261,13 +1271,15 @@ class Primaite(Env):
_LOGGER.error(msg)
raise Exception(msg)
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
servers = [node for node in node_list if
node.node_type == NodeType.SERVER]
for n, node in enumerate(nodes_to_be_compromised):
# 1: Use Node PoL to set node to compromised
_id = str(uuid.uuid4())
_start_step = randint(2, max_step_compromised + 1) # step compromised
_start_step = randint(2,
max_step_compromised + 1) # step compromised
pol_service_name = choice(list(node.services.keys()))
source_node_service = choice(list(source_node.services.values()))
@@ -1292,7 +1304,8 @@ class Primaite(Env):
ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
ier_start_step = randint(_start_step + 2,
int(self.episode_steps * 0.8))
ier_end_step = self.episode_steps
# Randomise the load, as a percentage of a random link bandwith
@@ -1315,15 +1328,16 @@ class Primaite(Env):
if len(possible_ier_destinations) < 1:
for server in servers:
if not self.acl.is_blocked(
node.ip_address,
server.ip_address,
ier_service,
ier_port,
node.ip_address,
server.ip_address,
ier_service,
ier_port,
):
possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1:
# If still none found choose from all servers
possible_ier_destinations = [server.node_id for server in servers]
possible_ier_destinations = [server.node_id for server in
servers]
ier_dest = choice(possible_ier_destinations)
self.red_iers[ier_id] = IER(
ier_id,

View File

@@ -0,0 +1,112 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: CUSTOM
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: DUMMY
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 2
# Number of time_steps per episode
num_steps: 15
# Time delay between steps (for generic agents)
time_delay: 1
# Type of session to be run (TRAINING or EVALUATION)
session_type: EVAL
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,68 +1,29 @@
from datetime import datetime
import pytest
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.environment.primaite_env import Primaite
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path
def run_generic(env, config_values):
"""Run against a generic agent."""
# Reset the environment at the start of the episode
env.reset()
for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Reset the environment at the end of the episode
env.reset()
env.close()
def test_random_red_agent_behaviour():
"""
Test that hardware state is penalised at each step.
When the initial state is OFF compared to reference state which is ON.
"""
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT / "test_random_red_main_config.yaml",
data_manipulation_config_path(),
]
],
indirect=True,
)
def test_random_red_agent_behaviour(temp_primaite_session):
"""Test that red agent POL is randomised each episode."""
list_of_node_instructions = []
# RUN TWICE so we can make sure that red agent is randomised
for i in range(2):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
with temp_primaite_session as session:
session.evaluate()
list_of_node_instructions.append(session.env.red_node_pol)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=data_manipulation_config_path(),
transaction_list=[],
session_path=session_path,
timestamp_str=timestamp_str,
)
# set red_agent_
env.training_config.random_red_agent = True
training_config = env.training_config
training_config.num_steps = env.episode_steps
run_generic(env, training_config)
# add red pol instructions to list
list_of_node_instructions.append(env.red_node_pol)
session.evaluate()
list_of_node_instructions.append(session.env.red_node_pol)
# compare instructions to make sure that red instructions are truly random
for index, instruction in enumerate(list_of_node_instructions):
@@ -73,5 +34,4 @@ def test_random_red_agent_behaviour():
print(f"{key} end step: {instruction.get_end_step()}")
print(f"{key} target node id: {instruction.get_target_node_id()}")
print("")
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])