|
|
|
|
@@ -1,14 +1,11 @@
|
|
|
|
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
|
|
|
|
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
|
|
|
|
import copy
|
|
|
|
|
import csv
|
|
|
|
|
import logging
|
|
|
|
|
import uuid as uuid
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Dict, Final, Tuple, Union
|
|
|
|
|
from random import choice, randint, sample, uniform
|
|
|
|
|
from typing import Dict, Tuple, Union
|
|
|
|
|
from typing import Dict, Final, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import networkx as nx
|
|
|
|
|
import numpy as np
|
|
|
|
|
@@ -321,11 +318,7 @@ class Primaite(Env):
|
|
|
|
|
link.clear_traffic()
|
|
|
|
|
|
|
|
|
|
# Create a Transaction (metric) object for this step
|
|
|
|
|
transaction = Transaction(
|
|
|
|
|
self.agent_identifier,
|
|
|
|
|
self.actual_episode_count,
|
|
|
|
|
self.step_count
|
|
|
|
|
)
|
|
|
|
|
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
|
|
|
|
|
# Load the initial observation space into the transaction
|
|
|
|
|
transaction.obs_space = self.obs_handler._flat_observation
|
|
|
|
|
|
|
|
|
|
@@ -436,12 +429,7 @@ class Primaite(Env):
|
|
|
|
|
for link_key, link_value in self.links.items():
|
|
|
|
|
_LOGGER.debug("Link ID: " + link_value.get_id())
|
|
|
|
|
for protocol in link_value.protocol_list:
|
|
|
|
|
print(
|
|
|
|
|
" Protocol: "
|
|
|
|
|
+ protocol.get_name().name
|
|
|
|
|
+ ", Load: "
|
|
|
|
|
+ str(protocol.get_load())
|
|
|
|
|
)
|
|
|
|
|
print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
|
|
|
|
|
|
|
|
|
def interpret_action_and_apply(self, _action):
|
|
|
|
|
"""
|
|
|
|
|
@@ -456,13 +444,9 @@ class Primaite(Env):
|
|
|
|
|
self.apply_actions_to_nodes(_action)
|
|
|
|
|
elif self.training_config.action_type == ActionType.ACL:
|
|
|
|
|
self.apply_actions_to_acl(_action)
|
|
|
|
|
elif (
|
|
|
|
|
len(self.action_dict[_action]) == 6
|
|
|
|
|
): # ACL actions in multidiscrete form have len 6
|
|
|
|
|
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
|
|
|
|
self.apply_actions_to_acl(_action)
|
|
|
|
|
elif (
|
|
|
|
|
len(self.action_dict[_action]) == 4
|
|
|
|
|
): # Node actions in multdiscrete (array) from have len 4
|
|
|
|
|
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
|
|
|
|
self.apply_actions_to_nodes(_action)
|
|
|
|
|
else:
|
|
|
|
|
logging.error("Invalid action type found")
|
|
|
|
|
@@ -528,9 +512,7 @@ class Primaite(Env):
|
|
|
|
|
return
|
|
|
|
|
elif property_action == 1:
|
|
|
|
|
# Patch (valid action if it's good or compromised)
|
|
|
|
|
node.set_service_state(
|
|
|
|
|
self.services_list[service_index], SoftwareState.PATCHING
|
|
|
|
|
)
|
|
|
|
|
node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING)
|
|
|
|
|
else:
|
|
|
|
|
# Node is not of Service Type
|
|
|
|
|
return
|
|
|
|
|
@@ -1238,11 +1220,7 @@ class Primaite(Env):
|
|
|
|
|
|
|
|
|
|
# Change node keys to not overlap with acl keys
|
|
|
|
|
# Only 1 nothing action (key 0) is required, remove the other
|
|
|
|
|
new_node_action_dict = {
|
|
|
|
|
k + len(acl_action_dict) - 1: v
|
|
|
|
|
for k, v in node_action_dict.items()
|
|
|
|
|
if k != 0
|
|
|
|
|
}
|
|
|
|
|
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
|
|
|
|
|
|
|
|
|
|
# Combine the Node dict and ACL dict
|
|
|
|
|
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
|
|
|
|
@@ -1256,11 +1234,8 @@ class Primaite(Env):
|
|
|
|
|
|
|
|
|
|
# Decide how many nodes become compromised
|
|
|
|
|
node_list = list(self.nodes.values())
|
|
|
|
|
computers = [node for node in node_list if
|
|
|
|
|
node.node_type == NodeType.COMPUTER]
|
|
|
|
|
max_num_nodes_compromised = len(
|
|
|
|
|
computers
|
|
|
|
|
) # only computers can become compromised
|
|
|
|
|
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
|
|
|
|
|
max_num_nodes_compromised = len(computers) # only computers can become compromised
|
|
|
|
|
# random select between 1 and max_num_nodes_compromised
|
|
|
|
|
num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
|
|
|
|
|
|
|
|
|
|
@@ -1271,9 +1246,7 @@ class Primaite(Env):
|
|
|
|
|
source_node = choice(nodes_to_be_compromised)
|
|
|
|
|
|
|
|
|
|
# For each of the nodes to be compromised decide which step they become compromised
|
|
|
|
|
max_step_compromised = (
|
|
|
|
|
self.episode_steps // 2
|
|
|
|
|
) # always compromise in first half of episode
|
|
|
|
|
max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode
|
|
|
|
|
|
|
|
|
|
# Bandwidth for all links
|
|
|
|
|
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
|
|
|
|
|
@@ -1283,15 +1256,13 @@ class Primaite(Env):
|
|
|
|
|
_LOGGER.error(msg)
|
|
|
|
|
raise Exception(msg)
|
|
|
|
|
|
|
|
|
|
servers = [node for node in node_list if
|
|
|
|
|
node.node_type == NodeType.SERVER]
|
|
|
|
|
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
|
|
|
|
|
|
|
|
|
|
for n, node in enumerate(nodes_to_be_compromised):
|
|
|
|
|
# 1: Use Node PoL to set node to compromised
|
|
|
|
|
|
|
|
|
|
_id = str(uuid.uuid4())
|
|
|
|
|
_start_step = randint(2,
|
|
|
|
|
max_step_compromised + 1) # step compromised
|
|
|
|
|
_start_step = randint(2, max_step_compromised + 1) # step compromised
|
|
|
|
|
pol_service_name = choice(list(node.services.keys()))
|
|
|
|
|
|
|
|
|
|
source_node_service = choice(list(source_node.services.values()))
|
|
|
|
|
@@ -1316,8 +1287,7 @@ class Primaite(Env):
|
|
|
|
|
|
|
|
|
|
ier_id = str(uuid.uuid4())
|
|
|
|
|
# Launch the attack after node is compromised, and not right at the end of the episode
|
|
|
|
|
ier_start_step = randint(_start_step + 2,
|
|
|
|
|
int(self.episode_steps * 0.8))
|
|
|
|
|
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
|
|
|
|
|
ier_end_step = self.episode_steps
|
|
|
|
|
|
|
|
|
|
# Randomise the load, as a percentage of a random link bandwith
|
|
|
|
|
@@ -1325,9 +1295,7 @@ class Primaite(Env):
|
|
|
|
|
ier_protocol = pol_service_name # Same protocol as compromised node
|
|
|
|
|
ier_service = node.services[pol_service_name]
|
|
|
|
|
ier_port = ier_service.port
|
|
|
|
|
ier_mission_criticality = (
|
|
|
|
|
0 # Red IER will never be important to green agent success
|
|
|
|
|
)
|
|
|
|
|
ier_mission_criticality = 0 # Red IER will never be important to green agent success
|
|
|
|
|
# We choose a node to attack based on the first that applies:
|
|
|
|
|
# a. Green IERs, select dest node of the red ier based on dest node of green IER
|
|
|
|
|
# b. Attack a random server that doesn't have a DENY acl rule in default config
|
|
|
|
|
@@ -1340,16 +1308,15 @@ class Primaite(Env):
|
|
|
|
|
if len(possible_ier_destinations) < 1:
|
|
|
|
|
for server in servers:
|
|
|
|
|
if not self.acl.is_blocked(
|
|
|
|
|
node.ip_address,
|
|
|
|
|
server.ip_address,
|
|
|
|
|
ier_service,
|
|
|
|
|
ier_port,
|
|
|
|
|
node.ip_address,
|
|
|
|
|
server.ip_address,
|
|
|
|
|
ier_service,
|
|
|
|
|
ier_port,
|
|
|
|
|
):
|
|
|
|
|
possible_ier_destinations.append(server.node_id)
|
|
|
|
|
if len(possible_ier_destinations) < 1:
|
|
|
|
|
# If still none found choose from all servers
|
|
|
|
|
possible_ier_destinations = [server.node_id for server in
|
|
|
|
|
servers]
|
|
|
|
|
possible_ier_destinations = [server.node_id for server in servers]
|
|
|
|
|
ier_dest = choice(possible_ier_destinations)
|
|
|
|
|
self.red_iers[ier_id] = IER(
|
|
|
|
|
ier_id,
|
|
|
|
|
|