#1522: fixing create random red agent function

This commit is contained in:
Czar Echavez
2023-06-29 15:03:11 +01:00
parent 15b3bad5d4
commit 10e432eb01
2 changed files with 78 additions and 73 deletions

2
.gitignore vendored
View File

@@ -137,3 +137,5 @@ dmypy.json
# Cython debug symbols
cython_debug/
.idea/

View File

@@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union
import networkx as nx
import numpy as np
import uuid as uuid
import yaml
from gym import Env, spaces
from matplotlib import pyplot as plt
@@ -58,12 +59,12 @@ class Primaite(Env):
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
transaction_list,
session_path: Path,
timestamp_str: str,
):
"""
The Primaite constructor.
@@ -275,8 +276,8 @@ class Primaite(Env):
self.reset_environment()
# Create a random red agent to use for this episode
if self.training_config.red_agent_identifier == "RANDOM":
self.create_random_red_agent()
# if self.training_config.red_agent_identifier == "RANDOM":
# self.create_random_red_agent()
# Reset counters and totals
self.total_reward = 0
@@ -380,6 +381,7 @@ class Primaite(Env):
self.nodes_post_pol,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.green_iers_reference,
self.red_iers,
self.step_count,
@@ -445,11 +447,11 @@ class Primaite(Env):
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
@@ -1247,14 +1249,17 @@ class Primaite(Env):
computers
) # only computers can become compromised
# random select between 1 and max_num_nodes_compromised
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1)
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised)
# Decide which of the nodes to compromise
nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise)
# choose a random compromise node to be source of attacks
source_node = np.random.choice(nodes_to_be_compromised, 1)[0]
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
self.episode_steps // 2
) # always compromise in first half of episode
# Bandwidth for all links
@@ -1264,57 +1269,50 @@ class Primaite(Env):
for n, node in enumerate(nodes_to_be_compromised):
# 1: Use Node PoL to set node to compromised
_id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate
_id = str(uuid.uuid4())
_start_step = np.random.randint(
2, max_step_compromised + 1
) # step compromised
_end_step = _start_step # Become compromised on 1 step
_target_node_id = node.node_id
_pol_initiator = "DIRECT"
_pol_type = NodePOLType["SERVICE"] # All computers are service nodes
pol_service_name = np.random.choice(
list(node.get_services().keys())
) # Random service may wish to change this, currently always TCP)
pol_protocol = pol_protocol
_pol_state = SoftwareState.COMPROMISED
is_entry_node = True # Assumes all computers in network are entry nodes
_pol_source_node_id = _pol_source_node_id
_pol_source_node_service = _pol_source_node_service
_pol_source_node_service_state = _pol_source_node_service_state
list(node.services.keys())
)
source_node_service = np.random.choice(
list(source_node.services.values())
)
red_pol = NodeStateInstructionRed(
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_pol_type,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
_id=_id,
_start_step=_start_step,
_end_step=_start_step, # only run for 1 step
_target_node_id=node.node_id,
_pol_initiator="DIRECT",
_pol_type=NodePOLType["SERVICE"],
pol_protocol=pol_service_name,
_pol_state=SoftwareState.COMPROMISED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state
)
self.red_node_pol[_id] = red_pol
# 2: Launch the attack from compromised node - set the IER
ier_id = str(2000 + n)
ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = np.random.randint(
_start_step + 2, int(self.episode_steps * 0.8)
)
ier_end_step = self.episode_steps
ier_source_node_id = node.get_id()
# Randomise the load, as a percentage of a random link bandwith
ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice(
bandwidths
)
ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.get_services()[
pol_service_name
] # same service as defined in the pol
ier_port = ier_service.get_port()
ier_service = node.services[pol_service_name]
ier_port = ier_service.port
ier_mission_criticality = (
0 # Red IER will never be important to green agent success
)
@@ -1325,15 +1323,15 @@ class Primaite(Env):
possible_ier_destinations = [
ier.get_dest_node_id()
for ier in list(self.green_iers.values())
if ier.get_source_node_id() == node.get_id()
if ier.get_source_node_id() == node.node_id
]
if len(possible_ier_destinations) < 1:
for server in servers:
if not self.acl.is_blocked(
node.get_ip_address(),
server.ip_address,
ier_service,
ier_port,
node.get_ip_address(),
server.ip_address,
ier_service,
ier_port,
):
possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1:
@@ -1347,37 +1345,42 @@ class Primaite(Env):
ier_load,
ier_protocol,
ier_port,
ier_source_node_id,
node.node_id,
ier_dest,
ier_mission_criticality,
)
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
# TODO remove duplicate red pol for same targetted service - must take into account start step
overwhelm_pol = red_pol
overwhelm_pol.id = str(uuid.uuid4())
overwhelm_pol.end_step = self.episode_steps
o_pol_id = str(3000 + n)
o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
o_pol_end_step = (
self.episode_steps
) # Can become compromised at any timestep after start
o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
o_pol_service_name = (
ier_protocol # Same protocol/service as the IER uses to attack
)
o_pol_new_state = SoftwareState["OVERWHELMED"]
o_pol_entry_node = False # Assumes servers are not entry nodes
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
# # TODO remove duplicate red pol for same targetted service - must take into account start step
#
o_pol_id = str(uuid.uuid4())
# o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
# o_pol_end_step = (
# self.episode_steps
# ) # Can become compromised at any timestep after start
# o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
# o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
# o_pol_service_name = (
# ier_protocol # Same protocol/service as the IER uses to attack
# )
# o_pol_new_state = SoftwareState["OVERWHELMED"]
# o_pol_entry_node = False # Assumes servers are not entry nodes
o_red_pol = NodeStateInstructionRed(
_id,
_start_step,
_end_step,
_target_node_id,
_pol_initiator,
_pol_type,
pol_protocol,
_pol_state,
_pol_source_node_id,
_pol_source_node_service,
_pol_source_node_service_state,
_id=o_pol_id,
_start_step=ier_start_step,
_end_step=self.episode_steps,
_target_node_id=ier_dest,
_pol_initiator="DIRECT",
_pol_type=NodePOLType["SERVICE"],
pol_protocol=ier_protocol,
_pol_state=SoftwareState.OVERWHELMED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state
)
self.red_node_pol[o_pol_id] = o_red_pol