#1522: fixing create random red agent function
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -137,3 +137,5 @@ dmypy.json
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
.idea/
|
||||
@@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import uuid as uuid
|
||||
import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
@@ -58,12 +59,12 @@ class Primaite(Env):
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
transaction_list,
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
):
|
||||
"""
|
||||
The Primaite constructor.
|
||||
@@ -275,8 +276,8 @@ class Primaite(Env):
|
||||
self.reset_environment()
|
||||
|
||||
# Create a random red agent to use for this episode
|
||||
if self.training_config.red_agent_identifier == "RANDOM":
|
||||
self.create_random_red_agent()
|
||||
# if self.training_config.red_agent_identifier == "RANDOM":
|
||||
# self.create_random_red_agent()
|
||||
|
||||
# Reset counters and totals
|
||||
self.total_reward = 0
|
||||
@@ -380,6 +381,7 @@ class Primaite(Env):
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
@@ -445,11 +447,11 @@ class Primaite(Env):
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
@@ -1247,14 +1249,17 @@ class Primaite(Env):
|
||||
computers
|
||||
) # only computers can become compromised
|
||||
# random select between 1 and max_num_nodes_compromised
|
||||
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1)
|
||||
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised)
|
||||
|
||||
# Decide which of the nodes to compromise
|
||||
nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise)
|
||||
|
||||
# choose a random compromise node to be source of attacks
|
||||
source_node = np.random.choice(nodes_to_be_compromised, 1)[0]
|
||||
|
||||
# For each of the nodes to be compromised decide which step they become compromised
|
||||
max_step_compromised = (
|
||||
self.episode_steps // 2
|
||||
self.episode_steps // 2
|
||||
) # always compromise in first half of episode
|
||||
|
||||
# Bandwidth for all links
|
||||
@@ -1264,57 +1269,50 @@ class Primaite(Env):
|
||||
for n, node in enumerate(nodes_to_be_compromised):
|
||||
# 1: Use Node PoL to set node to compromised
|
||||
|
||||
_id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate
|
||||
_id = str(uuid.uuid4())
|
||||
_start_step = np.random.randint(
|
||||
2, max_step_compromised + 1
|
||||
) # step compromised
|
||||
_end_step = _start_step # Become compromised on 1 step
|
||||
_target_node_id = node.node_id
|
||||
_pol_initiator = "DIRECT"
|
||||
_pol_type = NodePOLType["SERVICE"] # All computers are service nodes
|
||||
pol_service_name = np.random.choice(
|
||||
list(node.get_services().keys())
|
||||
) # Random service may wish to change this, currently always TCP)
|
||||
pol_protocol = pol_protocol
|
||||
_pol_state = SoftwareState.COMPROMISED
|
||||
is_entry_node = True # Assumes all computers in network are entry nodes
|
||||
_pol_source_node_id = _pol_source_node_id
|
||||
_pol_source_node_service = _pol_source_node_service
|
||||
_pol_source_node_service_state = _pol_source_node_service_state
|
||||
list(node.services.keys())
|
||||
)
|
||||
|
||||
source_node_service = np.random.choice(
|
||||
list(source_node.services.values())
|
||||
)
|
||||
|
||||
red_pol = NodeStateInstructionRed(
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
_id=_id,
|
||||
_start_step=_start_step,
|
||||
_end_step=_start_step, # only run for 1 step
|
||||
_target_node_id=node.node_id,
|
||||
_pol_initiator="DIRECT",
|
||||
_pol_type=NodePOLType["SERVICE"],
|
||||
pol_protocol=pol_service_name,
|
||||
_pol_state=SoftwareState.COMPROMISED,
|
||||
_pol_source_node_id=source_node.node_id,
|
||||
_pol_source_node_service=source_node_service.name,
|
||||
_pol_source_node_service_state=source_node_service.software_state
|
||||
)
|
||||
|
||||
self.red_node_pol[_id] = red_pol
|
||||
|
||||
# 2: Launch the attack from compromised node - set the IER
|
||||
|
||||
ier_id = str(2000 + n)
|
||||
ier_id = str(uuid.uuid4())
|
||||
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||
ier_start_step = np.random.randint(
|
||||
_start_step + 2, int(self.episode_steps * 0.8)
|
||||
)
|
||||
ier_end_step = self.episode_steps
|
||||
ier_source_node_id = node.get_id()
|
||||
|
||||
# Randomise the load, as a percentage of a random link bandwith
|
||||
ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice(
|
||||
bandwidths
|
||||
)
|
||||
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||
ier_service = node.get_services()[
|
||||
pol_service_name
|
||||
] # same service as defined in the pol
|
||||
ier_port = ier_service.get_port()
|
||||
ier_service = node.services[pol_service_name]
|
||||
ier_port = ier_service.port
|
||||
ier_mission_criticality = (
|
||||
0 # Red IER will never be important to green agent success
|
||||
)
|
||||
@@ -1325,15 +1323,15 @@ class Primaite(Env):
|
||||
possible_ier_destinations = [
|
||||
ier.get_dest_node_id()
|
||||
for ier in list(self.green_iers.values())
|
||||
if ier.get_source_node_id() == node.get_id()
|
||||
if ier.get_source_node_id() == node.node_id
|
||||
]
|
||||
if len(possible_ier_destinations) < 1:
|
||||
for server in servers:
|
||||
if not self.acl.is_blocked(
|
||||
node.get_ip_address(),
|
||||
server.ip_address,
|
||||
ier_service,
|
||||
ier_port,
|
||||
node.get_ip_address(),
|
||||
server.ip_address,
|
||||
ier_service,
|
||||
ier_port,
|
||||
):
|
||||
possible_ier_destinations.append(server.node_id)
|
||||
if len(possible_ier_destinations) < 1:
|
||||
@@ -1347,37 +1345,42 @@ class Primaite(Env):
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source_node_id,
|
||||
node.node_id,
|
||||
ier_dest,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
|
||||
# TODO remove duplicate red pol for same targetted service - must take into account start step
|
||||
overwhelm_pol = red_pol
|
||||
overwhelm_pol.id = str(uuid.uuid4())
|
||||
overwhelm_pol.end_step = self.episode_steps
|
||||
|
||||
o_pol_id = str(3000 + n)
|
||||
o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
|
||||
o_pol_end_step = (
|
||||
self.episode_steps
|
||||
) # Can become compromised at any timestep after start
|
||||
o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
|
||||
o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
|
||||
o_pol_service_name = (
|
||||
ier_protocol # Same protocol/service as the IER uses to attack
|
||||
)
|
||||
o_pol_new_state = SoftwareState["OVERWHELMED"]
|
||||
o_pol_entry_node = False # Assumes servers are not entry nodes
|
||||
|
||||
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
|
||||
# # TODO remove duplicate red pol for same targetted service - must take into account start step
|
||||
#
|
||||
o_pol_id = str(uuid.uuid4())
|
||||
# o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
|
||||
# o_pol_end_step = (
|
||||
# self.episode_steps
|
||||
# ) # Can become compromised at any timestep after start
|
||||
# o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
|
||||
# o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
|
||||
# o_pol_service_name = (
|
||||
# ier_protocol # Same protocol/service as the IER uses to attack
|
||||
# )
|
||||
# o_pol_new_state = SoftwareState["OVERWHELMED"]
|
||||
# o_pol_entry_node = False # Assumes servers are not entry nodes
|
||||
o_red_pol = NodeStateInstructionRed(
|
||||
_id,
|
||||
_start_step,
|
||||
_end_step,
|
||||
_target_node_id,
|
||||
_pol_initiator,
|
||||
_pol_type,
|
||||
pol_protocol,
|
||||
_pol_state,
|
||||
_pol_source_node_id,
|
||||
_pol_source_node_service,
|
||||
_pol_source_node_service_state,
|
||||
_id=o_pol_id,
|
||||
_start_step=ier_start_step,
|
||||
_end_step=self.episode_steps,
|
||||
_target_node_id=ier_dest,
|
||||
_pol_initiator="DIRECT",
|
||||
_pol_type=NodePOLType["SERVICE"],
|
||||
pol_protocol=ier_protocol,
|
||||
_pol_state=SoftwareState.OVERWHELMED,
|
||||
_pol_source_node_id=source_node.node_id,
|
||||
_pol_source_node_service=source_node_service.name,
|
||||
_pol_source_node_service_state=source_node_service.software_state
|
||||
)
|
||||
self.red_node_pol[o_pol_id] = o_red_pol
|
||||
|
||||
Reference in New Issue
Block a user