1429 - created new branch from dev, added enums to enums.py, created agents package and utils.py file, added option to primaite_env.py for ANY action type and changed the action spaces are defined using ADSP branch

This commit is contained in:
SunilSamra
2023-05-26 10:17:45 +01:00
parent f7638ddb0c
commit 3cd5864f25
6 changed files with 735 additions and 19 deletions

View File

View File

@@ -0,0 +1,509 @@
import logging
import os.path
from datetime import datetime
import numpy as np
import yaml
from primaite.common.config_values_main import ConfigValuesMain
from primaite.common.enums import (
ActionType,
HardwareState,
LinkStatus,
NodeHardwareAction,
NodePOLType,
NodeSoftwareAction,
SoftwareState,
)
def load_config_values(config_path):
"""Loads the config values from the main config file into a config object."""
config_file_main = open(config_path, "r")
config_data = yaml.safe_load(config_file_main)
# Create a config class
config_values = ConfigValuesMain()
try:
# Generic
config_values.red_agent_identifier = config_data["redAgentIdentifier"]
config_values.action_type = ActionType[config_data["actionType"]]
config_values.config_filename_use_case = config_data["configFilename"]
# Reward values
# Generic
config_values.all_ok = float(config_data["allOk"])
# Node Operating State
config_values.off_should_be_on = float(config_data["offShouldBeOn"])
config_values.off_should_be_resetting = float(
config_data["offShouldBeResetting"]
)
config_values.on_should_be_off = float(config_data["onShouldBeOff"])
config_values.on_should_be_resetting = float(config_data["onShouldBeResetting"])
config_values.resetting_should_be_on = float(config_data["resettingShouldBeOn"])
config_values.resetting_should_be_off = float(
config_data["resettingShouldBeOff"]
)
# Node O/S or Service State
config_values.good_should_be_patching = float(
config_data["goodShouldBePatching"]
)
config_values.good_should_be_compromised = float(
config_data["goodShouldBeCompromised"]
)
config_values.good_should_be_overwhelmed = float(
config_data["goodShouldBeOverwhelmed"]
)
config_values.patching_should_be_good = float(
config_data["patchingShouldBeGood"]
)
config_values.patching_should_be_compromised = float(
config_data["patchingShouldBeCompromised"]
)
config_values.patching_should_be_overwhelmed = float(
config_data["patchingShouldBeOverwhelmed"]
)
config_values.compromised_should_be_good = float(
config_data["compromisedShouldBeGood"]
)
config_values.compromised_should_be_patching = float(
config_data["compromisedShouldBePatching"]
)
config_values.compromised_should_be_overwhelmed = float(
config_data["compromisedShouldBeOverwhelmed"]
)
config_values.compromised = float(config_data["compromised"])
config_values.overwhelmed_should_be_good = float(
config_data["overwhelmedShouldBeGood"]
)
config_values.overwhelmed_should_be_patching = float(
config_data["overwhelmedShouldBePatching"]
)
config_values.overwhelmed_should_be_compromised = float(
config_data["overwhelmedShouldBeCompromised"]
)
config_values.overwhelmed = float(config_data["overwhelmed"])
# IER status
config_values.red_ier_running = float(config_data["redIerRunning"])
config_values.green_ier_blocked = float(config_data["greenIerBlocked"])
# Patching / Reset durations
config_values.os_patching_duration = int(config_data["osPatchingDuration"])
config_values.node_reset_duration = int(config_data["nodeResetDuration"])
config_values.service_patching_duration = int(
config_data["servicePatchingDuration"]
)
except Exception as e:
print(f"Could not save load config data: {e} ")
return config_values
def configure_logging(log_name):
"""Configures logging."""
try:
now = datetime.now() # current date and time
time = now.strftime("%Y%m%d_%H%M%S")
filename = "/app/logs/" + log_name + "/" + time + ".log"
path = f"/app/logs/{log_name}"
is_dir = os.path.isdir(path)
if not is_dir:
os.makedirs(path)
logging.basicConfig(
filename=filename,
filemode="w",
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
level=logging.INFO,
)
except Exception as e:
print("ERROR: Could not start logging", e)
def transform_change_obs_readable(obs):
"""Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
os_states = [SoftwareState(i).name for i in obs[:, 2]]
new_obs = [ids, operating_states, os_states]
for service in range(3, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
]
new_obs.append(service_states)
return new_obs
def transform_obs_readable(obs):
"""
Transform obs readable function.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']].
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
# Convert list of tuples to list of lists
new_obs = [list(i) for i in new_obs]
return new_obs
def convert_to_new_obs(obs, num_nodes=10):
"""Convert original gym Box observation space to new multiDiscrete observation space."""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""
Convert to old observation, links filled with 0's as no information is included in new observation space.
example:
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
new_obs = array([[ 1, 1, 1, 1],
[ 2, 1, 1, 1],
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
# Add empty links back and add node ID back
s = np.zeros(
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
dtype=np.int64,
)
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
s[:num_nodes, 1:] = reshaped_nodes # put values back in
new_obs = s
# Add links back in
links = obs[-num_links:]
# Links will be added to the last protocol/service slot but they are not specific to that service
new_obs[num_nodes:, -1] = links
return new_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""Return string describing change between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
list_of_changes = []
for n, row in enumerate(obs1 - obs2):
if row.any() != 0:
relevant_changes = np.where(row != 0, obs2[n], -1)
relevant_changes[0] = obs2[n, 0] # ID is always relevant
is_link = relevant_changes[0] > num_nodes
desc = _describe_obs_change_helper(relevant_changes, is_link)
list_of_changes.append(desc)
change_string = "\n ".join(list_of_changes)
if len(list_of_changes) > 0:
change_string = "\n " + change_string
return change_string
def _describe_obs_change_helper(obs_change, is_link):
"""
Helper funcion to describe what has changed.
example:
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
"""
# Indexes where a change has occured, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
node_pol_types = [
NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3)
for i in index_changed
]
# Account for hardware states, software sattes and links
states = [
LinkStatus(obs_change[i]).name
if is_link
else HardwareState(obs_change[i]).name
if i == 1
else SoftwareState(obs_change[i]).name
for i in index_changed
]
if not is_link:
desc = f"ID {obs_change[0]}:"
for node_pol_type, state in list(zip(node_pol_types, states)):
desc = desc + " " + node_pol_type + " changed to " + state + "."
else:
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
return desc
def transform_action_node_enum(action):
"""
Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
if action[1] == "OPERATING":
property_action = NodeHardwareAction[action[2]].value
elif action[1] == "OS" or action[1] == "SERVICE":
property_action = NodeSoftwareAction[action[2]].value
else:
property_action = 0
action_service_index = action[3]
new_action = [
action_node_id,
action_node_property,
property_action,
action_service_index,
]
return new_action
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
2
] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
return new_action
def node_action_description(action):
"""Generate string describing a node-based action."""
if isinstance(action[1], (int, np.int64)):
# transform action to readable format
action = transform_action_node_readable(action)
node_id = action[0]
node_property = action[1]
property_action = action[2]
service_id = action[3]
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
elif node_property == "SERVICE":
description = (
f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
)
else:
return ""
return description
def transform_action_acl_readable(action):
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def transform_action_acl_enum(action):
"""Convert a acl action from readable string format, to enumerated format."""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == "ANY":
new_action[n + 2] = 0
new_action = np.array(new_action)
return new_action
def acl_action_description(action):
"""Generate string describing a acl-based action."""
if isinstance(action[0], (int, np.int64)):
# transform action to readable format
action = transform_action_acl_readable(action)
if action[0] == "NONE":
description = "NO ACL RULE APPLIED"
else:
description = (
f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
f" for protocol/service index {action[4]} on port index {action[5]}"
)
return description
def get_node_of_ip(ip, node_dict):
"""
Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.get_ip_address()
if node_ip == ip:
return node_key
def is_valid_node_action(action):
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action."""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True
def get_new_action(old_action, action_dict):
"""Get new action (e.g. 32) from old action e.g. [1,1,1,0].
old_action can be either node or acl action type.
"""
for key, val in action_dict.items():
if list(val) == list(old_action):
return key
# Not all possible actions are included in dict, only valid action are
# if action is not in the dict, its an invalid action so return 0
return 0
def get_action_description(action, action_dict):
"""Get a string describing/explaining what an action is doing in words."""
action_array = action_dict[action]
if len(action_array) == 4:
# node actions have length 4
action_description = node_action_description(action_array)
elif len(action_array) == 6:
# acl actions have length 6
action_description = acl_action_description(action_array)
else:
# Should never happen
action_description = "Unrecognised action"
return action_description

View File

@@ -91,3 +91,29 @@ class FileSystemState(Enum):
DESTROYED = 3
REPAIRING = 4
RESTORING = 5
class NodeHardwareAction(Enum):
"""Node hardware action."""
NONE = 0
ON = 1
OFF = 2
RESET = 3
class NodeSoftwareAction(Enum):
"""Node software action."""
NONE = 0
PATCHING = 1
class LinkStatus(Enum):
"""Link traffic status."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
OVERLOAD = 4

View File

@@ -15,6 +15,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -232,15 +233,9 @@ class Primaite(Env):
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_space = spaces.MultiDiscrete(
[
self.num_nodes,
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
self.ACTION_SPACE_NODE_ACTION_VALUES,
self.num_services,
]
)
else:
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -249,16 +244,12 @@ class Primaite(Env):
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_space = spaces.MultiDiscrete(
[
self.ACTION_SPACE_ACL_ACTION_VALUES,
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
self.num_nodes + 1,
self.num_nodes + 1,
self.num_services + 1,
self.num_ports + 1,
]
)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info("Action space type ANY selected")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
# Set up a csv to store the results of the training
try:
@@ -1163,3 +1154,92 @@ class Primaite(Env):
else:
# Bad formatting
pass
def create_node_action_dict(self):
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
example return:
{0: [1, 0, 0, 0],
1: [1, 1, 1, 0],
2: [1, 1, 2, 0],
3: [1, 1, 3, 0],
4: [1, 2, 1, 0],
5: [1, 3, 1, 0],
...
}
"""
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
action_key = 1
for node in range(1, self.num_nodes + 1):
# 4 node properties (NONE, OPERATING, OS, SERVICE)
for node_property in range(4):
# Node Actions either:
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action, service_state]
# check to see if its a nothing aciton (has no effect)
if is_valid_node_action(action):
actions[action_key] = action
action_key += 1
return actions
def create_acl_action_dict(self):
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):
# 2 possible action permissions 0 = DENY, 1 = CREATE
for action_permission in range(2):
# Number of nodes + 1 (for any)
for source_ip in range(self.num_nodes + 1):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
]
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
return actions
def create_node_and_acl_action_dict(self):
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
The dictionary contains actions of both Node and ACL action types.
"""
node_action_dict = self.create_node_action_dict()
acl_action_dict = self.create_acl_action_dict()
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: GENERIC
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Operating State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node O/S or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,12 @@
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def test_single_action_space():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_lay_down_config.yaml",
)
print("Average Reward:", env.average_reward)