1404 lines
54 KiB
Python
1404 lines
54 KiB
Python
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
|
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
|
import copy
|
|
import logging
|
|
import uuid as uuid
|
|
from logging import Logger
|
|
from pathlib import Path
|
|
from random import choice, randint, sample, uniform
|
|
from typing import Any, Dict, Final, List, Tuple, Union
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import yaml
|
|
from gym import Env, spaces
|
|
from matplotlib import pyplot as plt
|
|
|
|
from primaite import getLogger
|
|
from primaite.acl.access_control_list import AccessControlList
|
|
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
|
from primaite.common.custom_typing import NodeUnion
|
|
from primaite.common.enums import (
|
|
ActionType,
|
|
AgentFramework,
|
|
AgentIdentifier,
|
|
FileSystemState,
|
|
HardwareState,
|
|
NodePOLInitiator,
|
|
NodePOLType,
|
|
NodeType,
|
|
ObservationType,
|
|
Priority,
|
|
SessionType,
|
|
SoftwareState,
|
|
)
|
|
from primaite.common.service import Service
|
|
from primaite.config import training_config
|
|
from primaite.config.training_config import TrainingConfig
|
|
from primaite.environment.observations import ObservationsHandler
|
|
from primaite.environment.reward import calculate_reward_function
|
|
from primaite.links.link import Link
|
|
from primaite.nodes.active_node import ActiveNode
|
|
from primaite.nodes.node import Node
|
|
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
|
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
|
from primaite.nodes.passive_node import PassiveNode
|
|
from primaite.nodes.service_node import ServiceNode
|
|
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
|
from primaite.pol.ier import IER
|
|
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
|
from primaite.transactions.transaction import Transaction
|
|
from primaite.utils.session_output_writer import SessionOutputWriter
|
|
|
|
_LOGGER: Logger = getLogger(__name__)
|
|
|
|
|
|
class Primaite(Env):
|
|
"""PRIMmary AI Training Evironment (Primaite) class."""
|
|
|
|
# Action Space constants
|
|
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
|
|
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
|
|
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
|
|
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
|
|
|
def __init__(
    self,
    training_config_path: Union[str, Path],
    lay_down_config_path: Union[str, Path],
    session_path: Path,
    timestamp_str: str,
) -> None:
    """
    Build the environment from a training config and a lay down config.

    The training config is loaded first, the empty containers and counters are created,
    the lay down YAML is parsed and turned into nodes/links/IERs/PoLs/ACL rules, a PNG of
    the network is saved into the session directory, and finally the observation space,
    the action space and the session output writers are set up.

    :param training_config_path: The training config filepath.
    :param lay_down_config_path: The lay down config filepath.
    :param session_path: The directory path the session is writing to.
    :param timestamp_str: The session timestamp in the format: <yyyy-mm-dd>_<hh-mm-ss>.
    """
    self.session_path: Final[Path] = session_path
    self.timestamp_str: Final[str] = timestamp_str
    self._training_config_path: Union[str, Path] = training_config_path
    self._lay_down_config_path: Union[str, Path] = lay_down_config_path

    self.training_config: TrainingConfig = training_config.load(training_config_path)
    _LOGGER.info(f"Using: {str(self.training_config)}")

    # Number of steps in an episode: evaluation sessions use the eval step count,
    # every other session type uses the training step count
    if self.training_config.session_type == SessionType.EVAL:
        steps_per_episode = self.training_config.num_eval_steps
    else:
        steps_per_episode = self.training_config.num_train_steps
    self.episode_steps: int = steps_per_episode

    super(Primaite, self).__init__()

    # The agent in use
    self.agent_identifier: AgentIdentifier = self.training_config.agent_identifier

    # Live nodes and links, each with a reference set kept alongside
    self.nodes: Dict[str, NodeUnion] = {}
    self.nodes_reference: Dict[str, NodeUnion] = {}
    self.links: Dict[str, Link] = {}
    self.links_reference: Dict[str, Link] = {}

    # Green IERs (live and reference) and green node PoLs
    self.green_iers: Dict[str, IER] = {}
    self.green_iers_reference: Dict[str, IER] = {}
    self.node_pol: Dict[str, NodeStateInstructionGreen] = {}

    # Red agent IERs and red agent node PoLs
    self.red_iers: Dict[str, IER] = {}
    self.red_node_pol: Dict[str, NodeStateInstructionRed] = {}

    # The Access Control List, plus the rule limit it was created with
    self.acl: AccessControlList = AccessControlList(
        self.training_config.implicit_acl_rule,
        self.training_config.max_number_acl_rules,
    )
    self.max_number_acl_rules: int = self.training_config.max_number_acl_rules

    # Services and ports named by the lay down config
    self.services_list: List[str] = []
    self.ports_list: List[str] = []

    # Network graph and its reference counterpart
    self.network: nx.Graph = nx.MultiGraph()
    self.network_reference: nx.Graph = nx.MultiGraph()

    # Step counter within the current episode
    self.step_count: int = 0
    # The total number of time steps completed
    self.total_step_count: int = 0
    # Step info dictionary returned from step()
    self.step_info: Dict[Any, Any] = {}

    # Reward bookkeeping and episode counter
    self.total_reward: float = 0
    self.average_reward: float = 0
    self.episode_count: int = 0

    # Sizes; nodes and links are counted below once the lay down config is loaded
    self.num_nodes: int = 0
    self.num_links: int = 0
    self.num_services: int = 0
    self.num_ports: int = 0

    # The action type
    # TODO: confirm type
    self.action_type: int = 0

    # TODO fix up with TrainingConfig
    # Observation config from the yaml; NODE_LINK_TABLE is used when none is given
    configured_observation_space = self.training_config.observation_space
    if configured_observation_space is None:
        self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
    else:
        self.obs_config = configured_observation_space

    # self.obs_handler (ObservationsHandler) manages the user-configurable
    # observation space; it is created later by init_observations()

    # The env observation space description for transactions writing
    self._obs_space_description: Union[List[str], None] = None

    # Parse the lay down config, then build the environment laydown from it
    with open(self._lay_down_config_path, "r") as lay_down_file:
        self.lay_down_config = yaml.safe_load(lay_down_file)
    self.load_lay_down_config()

    # Store the node objects as node attributes so they can be accessed as objects
    for graph in (self.network, self.network_reference):
        for graph_node in graph:
            graph.nodes[graph_node]["self"] = graph_node

    self.num_nodes = len(self.nodes)
    self.num_links = len(self.links)

    # Visualise in PNG; a failure here is logged and does not stop construction
    try:
        plt.tight_layout()
        nx.draw_networkx(self.network, with_labels=True)

        diagram_path = session_path / f"network_{timestamp_str}.png"
        plt.savefig(diagram_path, format="PNG")
        plt.clf()
    except Exception:
        _LOGGER.error("Could not save network diagram", exc_info=True)

    # Initiate observation space (spaces.Space) and the initial observation (np.ndarray)
    self.observation_space, self.env_obs = self.init_observations()

    # Define Action Space - depends on action space type (Node, ACL or both).
    # self.action_dict maps each discrete action to its decoded list of ints.
    action_dict_builder = None
    if self.training_config.action_type == ActionType.NODE:
        _LOGGER.debug("Action space type NODE selected")
        # Terms (for node action space):
        # [0, num nodes] - node ID (0 = nothing, node ID)
        # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
        # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
        # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
        action_dict_builder = self.create_node_action_dict
    elif self.training_config.action_type == ActionType.ACL:
        _LOGGER.debug("Action space type ACL selected")
        # Terms (for ACL action space):
        # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
        # [0, 1] - Permission (0 = DENY, 1 = ALLOW)
        # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
        # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
        # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
        # [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
        action_dict_builder = self.create_acl_action_dict
    elif self.training_config.action_type == ActionType.ANY:
        _LOGGER.debug("Action space type ANY selected - Node + ACL")
        action_dict_builder = self.create_node_and_acl_action_dict
    else:
        # No action dict / action space is created for an unknown action type
        _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")

    if action_dict_builder is not None:
        self.action_dict = action_dict_builder()
        self.action_space = spaces.Discrete(len(self.action_dict))

    # Output writers start out pointing at the learning session
    self.episode_av_reward_writer: SessionOutputWriter = SessionOutputWriter(
        self, transaction_writer=False, learning_session=True
    )
    self.transaction_writer: SessionOutputWriter = SessionOutputWriter(
        self, transaction_writer=True, learning_session=True
    )

    self.is_eval = False
|
|
|
|
@property
def actual_episode_count(self) -> int:
    """
    Return the episode count, shifted by -1 for an RLlib learning session.

    The shift applies only when the configured agent framework is RLlib and the
    environment has not been switched to evaluation; otherwise ``episode_count`` is
    returned unchanged.
    """
    rllib_learning_session = self.training_config.agent_framework is AgentFramework.RLLIB and not self.is_eval
    shift = 1 if rllib_learning_session else 0
    return self.episode_count - shift
|
|
|
|
def set_as_eval(self) -> None:
    """
    Switch the environment over to evaluation.

    Both output writers are replaced with writers created for a non-learning session,
    the episode and step counters are zeroed, the episode length becomes the configured
    number of eval steps, and ``is_eval`` is set.
    """
    # The writers are rebuilt before any counter is touched, as in the original ordering
    eval_session = {"learning_session": False}
    self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, **eval_session)
    self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, **eval_session)

    self.episode_count = self.step_count = self.total_step_count = 0
    self.episode_steps = self.training_config.num_eval_steps
    self.is_eval = True
|
|
|
|
def _write_av_reward_per_episode(self) -> None:
    """
    Write ``(actual_episode_count, average_reward)`` to the average reward writer.

    Nothing is written unless ``actual_episode_count`` is greater than zero.
    """
    if not (self.actual_episode_count > 0):
        return
    self.episode_av_reward_writer.write((self.actual_episode_count, self.average_reward))
|
|
|
|
def reset(self) -> np.ndarray:
    """
    AI Gym Reset function.

    Writes the average reward of the episode that just ended, increments the episode
    count, clears the ACL, resets the environment, optionally creates a random red
    agent, zeroes the per-episode totals and refreshes the observation.

    Returns:
        Environment observation space (reset)
    """
    # Record the finished episode before the counters move on
    self._write_av_reward_per_episode()
    self.episode_count += 1

    # Links are not reset here: their traffic is cleared and recalculated every step

    # Clear the ACL, then reset the node statuses and recreate the ACL from config
    # (this covers both the live and the reference nodes)
    self.init_acl()
    self.reset_environment()

    # Create a random red agent to use for this episode, when configured
    if self.training_config.random_red_agent:
        self._create_random_red_agent()

    # Reset counters and totals
    self.total_reward = 0.0
    self.step_count = 0
    self.average_reward = 0.0

    # Refresh the observation and hand it back
    self.update_environent_obs()
    return self.env_obs
|
|
|
|
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
    """
    AI Gym Step function.

    Applies the blue action, time-based updates, green PoL (to the live and the
    reference environment) and the red action, then calculates the reward, refreshes
    the observation and writes a transaction for the step.

    Args:
        action: Action space from agent

    Returns:
        env_obs: Observation space
        reward: Reward value for this step
        done: Indicates episode is complete if True
        step_info: Additional information relating to this step
    """
    # TEMP: only evaluation sessions ever flip this to True (see step 5)
    episode_done = False
    self.step_count += 1
    self.total_step_count += 1

    # Need to clear traffic on all links first (live, then reference)
    for live_link in self.links.values():
        live_link.clear_traffic()
    for reference_link in self.links_reference.values():
        reference_link.clear_traffic()

    # Create a Transaction (metric) object for this step and load it with the
    # initial observation, the observation description and the chosen action
    transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
    transaction.obs_space = self.obs_handler._flat_observation
    transaction.obs_space_description = self._obs_space_description
    transaction.action_space = copy.deepcopy(action)

    # 1. Implement Blue Action, then snapshot nodes and links
    self.interpret_action_and_apply(action)
    self.nodes_post_blue = copy.deepcopy(self.nodes)
    self.links_post_blue = copy.deepcopy(self.links)

    # 2. Perform any time-based activities (e.g. a component moving from patching to good)
    self.apply_time_based_updates()

    # 3. Apply PoL to the live environment (node PoL, then network PoL) and snapshot
    apply_node_pol(self.nodes, self.node_pol, self.step_count)
    apply_iers(self.network, self.nodes, self.links, self.green_iers, self.acl, self.step_count)
    self.nodes_post_pol = copy.deepcopy(self.nodes)
    self.links_post_pol = copy.deepcopy(self.links)
    # ... and the same PoL to the reference environment
    apply_node_pol(self.nodes_reference, self.node_pol, self.step_count)
    apply_iers(
        self.network_reference,
        self.nodes_reference,
        self.links_reference,
        self.green_iers_reference,
        self.acl,
        self.step_count,
    )

    # 4. Implement Red Action, then snapshot nodes and links
    apply_red_agent_iers(self.network, self.nodes, self.links, self.red_iers, self.acl, self.step_count)
    apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
    self.nodes_post_red = copy.deepcopy(self.nodes)
    self.links_post_red = copy.deepcopy(self.links)

    # 5. Calculate reward signal (for RL)
    reward = calculate_reward_function(
        self.nodes_post_pol,
        self.nodes_post_red,
        self.nodes_reference,
        self.green_iers,
        self.green_iers_reference,
        self.red_iers,
        self.step_count,
        self.training_config,
    )
    _LOGGER.debug(f"Episode: {self.actual_episode_count}, Step {self.step_count}, Reward: {reward}")
    self.total_reward += reward
    if self.step_count == self.episode_steps:
        self.average_reward = self.total_reward / self.step_count
        if self.training_config.session_type is SessionType.EVAL:
            # For evaluation, need to trigger the done value = True when
            # step count is reached in order to prevent neverending episode
            episode_done = True
        _LOGGER.info(f"Episode: {self.actual_episode_count}, Average Reward: {self.average_reward}")
    # Load the reward into the transaction
    transaction.reward = reward

    # 6. Output Verbose
    # self.output_link_status()

    # 7. Update env_obs
    self.update_environent_obs()

    # Write transaction to file
    if self.actual_episode_count > 0:
        self.transaction_writer.write(transaction)

    return self.env_obs, reward, episode_done, self.step_info
|
|
|
|
def close(self) -> None:
    """Close the parent environment, then the transaction and average reward writers."""
    super().close()

    # Writers are looked up and closed one at a time, transaction writer first
    for writer_attribute in ("transaction_writer", "episode_av_reward_writer"):
        getattr(self, writer_attribute).close()
|
|
|
|
def init_acl(self) -> None:
    """Initialise the Access Control List by calling its ``remove_all_rules()``."""
    access_control_list = self.acl
    access_control_list.remove_all_rules()
|
|
|
|
def output_link_status(self) -> None:
    """
    Output the link status of all links.

    The link ID goes to the module logger at debug level; each protocol's name and
    load is printed to the console.
    """
    for link in self.links.values():
        # f-strings rather than "+" so a non-string link ID (IDs come straight from the
        # lay down config) is formatted instead of raising TypeError
        _LOGGER.debug(f"Link ID: {link.get_id()}")
        for protocol in link.protocol_list:
            print(f" Protocol: {protocol.get_name().name}, Load: {protocol.get_load()}")
|
|
|
|
def interpret_action_and_apply(self, _action: int) -> None:
    """
    Applies agent actions to the nodes and Access Control List.

    For the NODE and ACL action types the handler is fixed by the training config.
    For any other action type the handler is chosen from the length of the decoded
    action in ``self.action_dict``: 7 entries is an ACL action, 4 entries is a node
    action, anything else is logged as an error and ignored.

    Args:
        _action: The action space from the agent
    """
    configured_type = self.training_config.action_type
    if configured_type == ActionType.NODE:
        self.apply_actions_to_nodes(_action)
        return
    if configured_type == ActionType.ACL:
        self.apply_actions_to_acl(_action)
        return

    # Mixed action space: tell the two kinds of action apart by their decoded length
    decoded_length = len(self.action_dict[_action])
    if decoded_length == 7:  # ACL actions in multidiscrete form have len 7
        self.apply_actions_to_acl(_action)
    elif decoded_length == 4:  # Node actions in multidiscrete (array) form have len 4
        self.apply_actions_to_nodes(_action)
    else:
        # Use the module logger: the root-level logging.error() would implicitly
        # configure the root logger and bypass this module's logger
        _LOGGER.error("Invalid action type found")
|
|
|
|
def apply_actions_to_nodes(self, _action: int) -> None:
    """
    Applies agent actions to the nodes.

    The discrete action is decoded through ``self.action_dict`` into
    ``[node ID, node property, property action, service index]``. Unknown node IDs,
    unknown properties/actions and actions that do not fit the node's class are
    ignored.

    Args:
        _action: The action space from the agent
    """
    decoded = self.action_dict[_action]
    node_id, node_property, property_action, service_index = (decoded[i] for i in range(4))

    # Check that the action is requesting a valid node; only a missing key is
    # expected here, so nothing broader than KeyError is swallowed
    try:
        node = self.nodes[str(node_id)]
    except KeyError:
        return

    # node_property 0 and property_action 0 both mean "do nothing"; they simply fall
    # through every branch below
    if node_property == 1:
        # Action on the node Hardware State
        if property_action == 1:
            # Turn on (only applicable if it's OFF, not if it's patching)
            if node.hardware_state == HardwareState.OFF:
                node.turn_on()
        elif property_action == 2:
            # Turn off
            node.turn_off()
        elif property_action == 3:
            # Reset (only applicable if it's ON)
            if node.hardware_state == HardwareState.ON:
                node.reset()
    elif node_property == 2:
        # Action on the node Software State (Active or Service nodes only)
        if isinstance(node, (ActiveNode, ServiceNode)) and property_action == 1:
            # Patch
            node.software_state = SoftwareState.PATCHING
    elif node_property == 3:
        # Action on a node Service State (Service nodes only)
        if isinstance(node, ServiceNode) and property_action == 1:
            # Patch the service selected by the service index
            node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING)
    elif node_property == 4:
        # Action on a node file system state (Active nodes only)
        if isinstance(node, ActiveNode):
            if property_action == 1:
                # Scan
                node.start_file_system_scan()
            elif property_action == 2:
                # Repair: you cannot repair a destroyed file system - it needs restoring
                if node.file_system_state_actual != FileSystemState.DESTROYED:
                    node.set_file_system_state(FileSystemState.REPAIRING)
            elif property_action == 3:
                # Restore
                node.set_file_system_state(FileSystemState.RESTORING)
|
|
|
|
def apply_actions_to_acl(self, _action: int) -> None:
    """
    Applies agent actions to the Access Control List.

    The discrete action is decoded through ``self.action_dict`` into seven values:
    decision (0 = nothing, 1 = add rule, 2 = remove rule), permission (0 = DENY,
    otherwise ALLOW), source IP index, destination IP index, protocol index, port
    index and rule position. An index of 0 stands for "ANY"; any other index ``i``
    selects entry ``i - 1`` of the nodes, services or ports. The action is dropped
    when a selected node is not an Active or Service node.

    Args:
        _action: The action space from the agent
    """
    # Convert discrete value back to multidiscrete
    decoded = self.action_dict[_action]
    (
        decision,
        permission_flag,
        source_index,
        destination_index,
        protocol_index,
        port_index,
        rule_position,
    ) = (decoded[i] for i in range(7))

    if decision == 0:
        # It's decided to do nothing
        return

    permission = "DENY" if permission_flag == 0 else "ALLOW"

    # Resolve source first, then destination, bailing out on the first unusable node
    addresses = []
    for ip_index in (source_index, destination_index):
        if ip_index == 0:
            addresses.append("ANY")
            continue
        selected_node = list(self.nodes.values())[ip_index - 1]
        if not isinstance(selected_node, (ServiceNode, ActiveNode)):
            return
        addresses.append(selected_node.ip_address)
    source_address, destination_address = addresses

    protocol = "ANY" if protocol_index == 0 else self.services_list[protocol_index - 1]
    port = "ANY" if port_index == 0 else self.ports_list[port_index - 1]

    # Now add or remove; any other decision value is ignored
    if decision == 1:
        self.acl.add_rule(permission, source_address, destination_address, protocol, port, rule_position)
    elif decision == 2:
        self.acl.remove_rule(permission, source_address, destination_address, protocol, port)
|
|
|
|
def apply_time_based_updates(self) -> None:
    """
    Updates anything that needs to count down and then change state.

    e.g. reset / patching status. The live nodes are processed first, then the
    reference nodes, with identical rules.
    """
    for node_set in (self.nodes, self.nodes_reference):
        for node in node_set.values():
            # Any node that is resetting
            if node.hardware_state == HardwareState.RESETTING:
                node.update_resetting_status()

            # Active and Service nodes: file system state, plus OS patching when in progress
            if isinstance(node, (ActiveNode, ServiceNode)):
                node.update_file_system_state()
                if node.software_state == SoftwareState.PATCHING:
                    node.update_os_patching_status()

            # Service nodes: patching status of their services
            if isinstance(node, ServiceNode):
                node.update_services_patching_status()
|
|
|
|
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
    """
    Create the environment's observation handler from the observation config.

    The handler is stored on ``self.obs_handler``. The observation space description
    is taken from the handler only when none has been stored yet.

    :return: The observation space, initial observation (zeroed out array with the correct shape)
    :rtype: Tuple[spaces.Space, np.ndarray]
    """
    handler = ObservationsHandler.from_config(self, self.obs_config)
    self.obs_handler = handler

    if not self._obs_space_description:
        self._obs_space_description = handler.describe_structure()

    return handler.space, handler.current_observation
|
|
|
|
def update_environent_obs(self) -> None:
    """Ask the observation handler to update, then copy its current observation to ``env_obs``."""
    handler = self.obs_handler
    handler.update_obs()
    self.env_obs = handler.current_observation
|
|
|
|
def load_lay_down_config(self) -> None:
    """
    Loads config data in order to build the environment configuration.

    Each item of ``self.lay_down_config`` is passed to the ``create_*`` method that
    matches its ``item_type``; items with an unrecognised type are logged as errors
    and skipped.
    """
    # item_type -> name of the method that builds it (looked up only when needed)
    builder_names = {
        "NODE": "create_node",
        "LINK": "create_link",
        "GREEN_IER": "create_green_ier",
        "GREEN_POL": "create_green_pol",
        "RED_IER": "create_red_ier",
        "RED_POL": "create_red_pol",
        "ACL_RULE": "create_acl_rule",
        "SERVICES": "create_services_list",
        "PORTS": "create_ports_list",
    }

    for config_item in self.lay_down_config:
        item_type = config_item["item_type"]
        # Only strings can name a builder; this also keeps unhashable values out of the lookup
        builder_name = builder_names.get(item_type) if isinstance(item_type, str) else None
        if builder_name is None:
            _LOGGER.error(f"Invalid item_type: {item_type}")
            continue
        getattr(self, builder_name)(config_item)

    _LOGGER.info("Environment configuration loaded")
    print("Environment configuration loaded")
|
|
|
|
def create_node(self, item: Dict) -> None:
    """
    Creates a node from config data.

    The node is added to the node dictionary and the network; a deep copy of it is
    added to the reference node dictionary and the reference network.

    Args:
        item: A config data item

    Raises:
        ValueError: If ``item["node_class"]`` is not PASSIVE, ACTIVE or SERVICE.
    """
    # All nodes have these parameters
    node_id = item["node_id"]
    node_name = item["name"]
    node_class = item["node_class"]
    node_type = NodeType[item["node_type"]]
    node_priority = Priority[item["priority"]]
    node_hardware_state = HardwareState[item["hardware_state"]]

    if node_class == "PASSIVE":
        node = PassiveNode(
            node_id,
            node_name,
            node_type,
            node_priority,
            node_hardware_state,
            self.training_config,
        )
    elif node_class in ("ACTIVE", "SERVICE"):
        # Active and Service nodes have IP address, Software State and file system state
        is_service_node = node_class == "SERVICE"
        node_constructor = ServiceNode if is_service_node else ActiveNode
        node_ip_address = item["ip_address"]
        node_software_state = SoftwareState[item["software_state"]]
        node_file_system_state = FileSystemState[item["file_system_state"]]
        node = node_constructor(
            node_id,
            node_name,
            node_type,
            node_priority,
            node_hardware_state,
            node_ip_address,
            node_software_state,
            node_file_system_state,
            self.training_config,
        )
        if is_service_node:
            # Service nodes additionally carry a list of services
            for service in item["services"]:
                service_protocol = service["name"]
                service_port = service["port"]
                service_state = SoftwareState[service["state"]]
                node.add_service(Service(service_protocol, service_port, service_state))
    else:
        # Bad formatting: fail with a clear message instead of the UnboundLocalError
        # that copying an unassigned node used to produce
        message = f"Invalid node_class {node_class!r} for node {node_id!r}; expected PASSIVE, ACTIVE or SERVICE"
        _LOGGER.error(message)
        raise ValueError(message)

    # Copy the node for the reference version
    node_ref = copy.deepcopy(node)

    # Register the live node and its reference copy
    self.nodes[node_id] = node
    self.nodes_reference[node_id] = node_ref

    # Add both to their networks
    self.network.add_nodes_from([node])
    self.network_reference.add_nodes_from([node_ref])
|
|
|
|
def create_link(self, item: Dict) -> None:
    """
    Creates a link from config data.

    An edge is added to the network between the source and destination nodes and a
    ``Link`` is stored under the link name; the same is then done for the reference
    network using the reference nodes.

    Args:
        item: A config data item
    """
    link_id = item["id"]
    link_name = item["name"]
    link_bandwidth = item["bandwidth"]
    link_source = item["source"]
    link_destination = item["destination"]

    # Live environment first, reference environment second
    environments = (
        (self.nodes, self.network, self.links),
        (self.nodes_reference, self.network_reference, self.links_reference),
    )
    for node_lookup, graph, link_store in environments:
        source_node = node_lookup[link_source]
        dest_node = node_lookup[link_destination]

        # Add link to the network, then to the link dictionary
        graph.add_edge(source_node, dest_node, id=link_name)
        link_store[link_name] = Link(
            link_id,
            link_bandwidth,
            source_node.name,
            dest_node.name,
            self.services_list,
        )
|
|
|
|
def create_green_ier(self, item: Dict) -> None:
    """
    Creates a green IER from config data.

    Two separate IER objects are built from the same values: one for the green IER
    dictionary and one for the reference green IER dictionary.

    Args:
        item: A config data item
    """
    # Config keys in the positional order the IER constructor is called with
    ier_fields = (
        "id",
        "start_step",
        "end_step",
        "load",
        "protocol",
        "port",
        "source",
        "destination",
        "mission_criticality",
    )
    ier_values = [item[field] for field in ier_fields]
    ier_id = ier_values[0]

    self.green_iers[ier_id] = IER(*ier_values)
    self.green_iers_reference[ier_id] = IER(*ier_values)
|
|
|
|
def create_red_ier(self, item: Dict) -> None:
    """
    Creates a red IER from config data and stores it in the red IER dictionary.

    Args:
        item: A config data item
    """
    # Config keys in the positional order the IER constructor is called with
    ier_fields = (
        "id",
        "start_step",
        "end_step",
        "load",
        "protocol",
        "port",
        "source",
        "destination",
        "mission_criticality",
    )
    ier_values = [item[field] for field in ier_fields]

    self.red_iers[ier_values[0]] = IER(*ier_values)
|
|
|
|
def create_green_pol(self, item: Dict) -> None:
    """
    Creates a green PoL object from config data.

    Args:
        item: A config data item
    """
    pol_id = item["id"]
    pol_start_step = item["start_step"]
    pol_end_step = item["end_step"]
    pol_node = item["nodeId"]
    pol_type = NodePOLType[item["type"]]

    # The enum the state is read from depends on the PoL type: Operating -> hardware
    # state, File -> file system state, anything else -> software state
    if pol_type == NodePOLType.OPERATING:
        state_enum = HardwareState
    elif pol_type == NodePOLType.FILE:
        state_enum = FileSystemState
    else:
        state_enum = SoftwareState

    # Only the software-state PoLs read a protocol from the config
    pol_protocol = item["protocol"] if state_enum is SoftwareState else ""
    pol_state = state_enum[item["state"]]

    self.node_pol[pol_id] = NodeStateInstructionGreen(
        pol_id,
        pol_start_step,
        pol_end_step,
        pol_node,
        pol_type,
        pol_protocol,
        pol_state,
    )
|
|
|
|
def create_red_pol(self, item: Dict) -> None:
    """
    Build a red PoL (node state instruction) from one config entry and register it.

    The instruction is stored in ``self.red_node_pol`` under the entry's ``id``.

    Args:
        item: A config data item describing the red PoL
    """
    instruction_id = item["id"]
    first_step = item["start_step"]
    last_step = item["end_step"]
    target_node_id = item["targetNodeId"]
    initiator = NodePOLInitiator[item["initiator"]]
    kind = NodePOLType[item["type"]]
    protocol = item["protocol"]

    # The enum used to parse "state" depends on the PoL type; anything other than
    # OPERATING or FILE is parsed as a software state
    state_enum_by_type = {
        NodePOLType.OPERATING: HardwareState,
        NodePOLType.FILE: FileSystemState,
    }
    state = state_enum_by_type.get(kind, SoftwareState)[item["state"]]

    source_node_id = item["sourceNodeId"]
    source_node_service = item["sourceNodeService"]
    source_node_service_state = item["sourceNodeServiceState"]

    self.red_node_pol[instruction_id] = NodeStateInstructionRed(
        instruction_id,
        first_step,
        last_step,
        target_node_id,
        initiator,
        kind,
        protocol,
        state,
        source_node_id,
        source_node_service,
        source_node_service_state,
    )
|
|
|
|
def create_acl_rule(self, item: Dict) -> None:
    """
    Add one ACL rule, described by a config entry, to ``self.acl``.

    Args:
        item: A config data item describing the ACL rule
    """
    # Config keys, listed in the positional order ``add_rule`` is called with
    rule_keys = ("permission", "source", "destination", "protocol", "port", "position")
    rule_args = [item[key] for key in rule_keys]

    self.acl.add_rule(*rule_args)
|
|
|
|
# TODO: confirm typehint using runtime
|
|
def create_services_list(self, services: Dict) -> None:
    """
    Record the name of every configured service.

    The ``name`` of each entry under the ``service_list`` key is appended to
    ``self.services_list``, then ``self.num_services`` is refreshed.

    Args:
        services: A config data item representing the services
    """
    configured_services = services["service_list"]
    self.services_list.extend(entry["name"] for entry in configured_services)

    # Keep the cached service count in step with the list
    self.num_services = len(self.services_list)
|
|
|
|
def create_ports_list(self, ports: Dict) -> None:
    """
    Record every configured port.

    The ``port`` value of each entry under the ``ports_list`` key is appended to
    ``self.ports_list``, then ``self.num_ports`` is refreshed.

    Args:
        ports: A config data item representing the ports
    """
    configured_ports = ports["ports_list"]
    self.ports_list.extend(entry["port"] for entry in configured_ports)

    # Keep the cached port count in step with the list
    self.num_ports = len(self.ports_list)
|
|
|
|
# TODO: this is not used anymore, write a ticket to delete it
|
|
def get_observation_info(self, observation_info: Dict) -> None:
    """
    Read the observation type out of the config and store it on the environment.

    :param observation_info: Config item that defines which type of observation space to use;
        its ``type`` entry must be the name of an ``ObservationType`` member
    :type observation_info: Dict
    """
    type_name = observation_info["type"]
    self.observation_type = ObservationType[type_name]
|
|
|
|
# TODO: this is not used anymore, write a ticket to delete it.
|
|
def get_action_info(self, action_info: Dict) -> None:
    """
    Read the action type out of the config and store it on the environment.

    Args:
        action_info: A config data item representing action info; its ``type``
            entry must be the name of an ``ActionType`` member
    """
    type_name = action_info["type"]
    self.action_type = ActionType[type_name]
|
|
|
|
def save_obs_config(self, obs_config: dict) -> None:
    """
    Keep hold of the observation space config for later use.

    The observation space cannot be built while the config is still being read;
    it has to wait until all the nodes, links, and services have been initialised.
    The config is therefore only stored here, unchanged, on ``self.obs_config``.

    :param obs_config: Parsed config relating to the observation space. The format is described in
        :py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
    :type obs_config: dict
    """
    # Stored by reference; no copy or validation happens at this point
    self.obs_config = obs_config
|
|
|
|
def reset_environment(self) -> None:
    """
    Reset the environment using the lay down config.

    NODE items have their state reset, ACL_RULE items are re-created, and every
    green and red IER is marked as not running.
    """
    for entry in self.lay_down_config:
        entry_type = entry["item_type"]
        if entry_type == "NODE":
            # Reset a node's state (normal and reference)
            self.reset_node(entry)
        elif entry_type == "ACL_RULE":
            # ACL rules are cleared on reset, so they only need to be created again
            self.create_acl_rule(entry)
        # Anything else is badly formatted or not relevant to a reset: ignore it

    # No IER should be running at the start of an episode: green first, then red
    for green_ier in self.green_iers.values():
        green_ier.set_is_running(False)
    for red_ier in self.red_iers.values():
        red_ier.set_is_running(False)
|
|
|
|
def reset_node(self, item: Dict) -> None:
    """
    Put a node, and its reference copy, back into the states given in the config.

    Args:
        item: A config data item describing the node
    """
    # All nodes have these parameters
    node_id = item["node_id"]
    node_class = item["node_class"]
    configured_hardware_state: HardwareState = HardwareState[item["hardware_state"]]

    # Every change is applied to the live node first and then to its reference copy
    node: NodeUnion = self.nodes[node_id]
    node_ref = self.nodes_reference[node_id]
    node_pair = (node, node_ref)

    # Hardware state is common to all node types
    for target in node_pair:
        target.hardware_state = configured_hardware_state

    if node_class not in ("ACTIVE", "SERVICE"):
        # Other node classes (or badly formatted items) have nothing further to reset
        return

    # Active and service nodes both have a software state and a file system state
    configured_software_state = SoftwareState[item["software_state"]]
    configured_file_system_state = FileSystemState[item["file_system_state"]]
    for target in node_pair:
        target.software_state = configured_software_state
    for target in node_pair:
        target.set_file_system_state(configured_file_system_state)

    if node_class == "SERVICE":
        # Service nodes additionally carry a list of services, each with its own state
        for service in item["services"]:
            service_protocol = service["name"]
            service_state = SoftwareState[service["state"]]
            for target in node_pair:
                target.set_service_state(service_protocol, service_state)
|
|
|
|
def create_node_action_dict(self) -> Dict[int, List[int]]:
    """
    Build the lookup from discrete node action keys to multidiscrete node actions.

    Key 0 is reserved for the nothing action. Every other candidate is kept only if
    ``is_valid_node_action`` accepts it (the check exists to drop actions that would
    have no effect), and accepted candidates are numbered consecutively from 1.

    example return:
    {0: [1, 0, 0, 0],
    1: [1, 1, 1, 0],
    2: [1, 1, 2, 0],
    3: [1, 1, 3, 0],
    4: [1, 2, 1, 0],
    5: [1, 3, 1, 0],
    ...
    }
    """
    # Terms (for node action space):
    # [0, num nodes] - node ID (0 = nothing, node ID)
    # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
    # [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
    # [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
    #
    # Candidates are generated as [node, node_property, node_action, service_state] with:
    #   node          - 1 .. num_nodes
    #   node_property - 0 .. 3 (NONE, OPERATING, OS, SERVICE)
    #   node_action   - 0 .. 3; the maximum is used so that every property's actions are covered
    #   service_state - 0 .. num_services - 1
    candidate_actions = (
        [node, node_property, node_action, service_state]
        for node in range(1, self.num_nodes + 1)
        for node_property in range(4)
        for node_action in range(4)
        for service_state in range(self.num_services)
    )

    # reserve 0 action to be a nothing action
    actions = {0: [1, 0, 0, 0]}
    accepted_actions = filter(is_valid_node_action, candidate_actions)
    for action_key, action in enumerate(accepted_actions, start=1):
        actions[action_key] = action

    return actions
|
|
|
|
def create_acl_action_dict(self) -> Dict[int, List[int]]:
    """
    Build the lookup from discrete ACL action keys to multidiscrete ACL actions.

    Key 0 is reserved for the nothing action. Every other candidate is kept only if
    ``is_valid_acl_action_extra`` accepts it, and accepted candidates are numbered
    consecutively from 1.
    """
    # Terms (for ACL action space):
    # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
    # [0, 1] - Permission (0 = DENY, 1 = ALLOW)
    # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
    # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
    # [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
    # [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
    # [0, max acl rules - 1] - Position (0 = first index, then 1 -> x index resolving to acl rule in acl list)
    #
    # Candidates are generated as
    # [decision, permission, source_ip, dest_ip, protocol, port, position]
    candidate_actions = (
        [decision, permission, source_ip, dest_ip, protocol, port, position]
        # 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
        for decision in range(3)
        # 2 possible action permissions 0 = DENY, 1 = ALLOW
        for permission in range(2)
        # Number of nodes + 1 (for any)
        for source_ip in range(self.num_nodes + 1)
        for dest_ip in range(self.num_nodes + 1)
        for protocol in range(self.num_services + 1)
        for port in range(self.num_ports + 1)
        for position in range(self.max_number_acl_rules - 1)
    )

    # reserve 0 action to be a nothing action
    actions = {0: [0, 0, 0, 0, 0, 0, 0]}

    # Only keep the actions we want to include as possible, i.e. not a nothing action
    accepted_actions = filter(is_valid_acl_action_extra, candidate_actions)
    for action_key, action in enumerate(accepted_actions, start=1):
        actions[action_key] = action

    return actions
|
|
|
|
def create_node_and_acl_action_dict(self) -> Dict[int, List[int]]:
    """
    Build one discrete-to-multidiscrete action lookup covering both Node and ACL actions.

    The ACL actions keep their keys. The node actions, minus their own nothing action
    (key 0), are renumbered to follow on directly after the ACL keys.
    """
    node_actions = self.create_node_action_dict()
    acl_actions = self.create_acl_action_dict()

    # Node key 1 must land on the first key after the ACL actions
    key_offset = len(acl_actions) - 1

    combined_actions = dict(acl_actions)
    for node_key, node_action in node_actions.items():
        if node_key == 0:
            # Only 1 nothing action (key 0) is required; the ACL one is kept
            continue
        combined_actions[node_key + key_offset] = node_action

    return combined_actions
|
|
|
|
def _create_random_red_agent(self) -> None:
    """
    Decide on random red agent for the episode to be called in env.reset().

    ``self.red_iers`` and ``self.red_node_pol`` are replaced with freshly generated
    entries. A random, non-empty subset of the COMPUTER nodes is selected and, for
    each selected node:

    1. a red node PoL sets one of its services to COMPROMISED,
    2. a red IER using that service's protocol and port is launched from the node, and
    3. a red node PoL sets the same protocol's service on the IER destination to
       OVERWHELMED.

    Raises:
        ValueError: If the network has no COMPUTER nodes, or if the episode is too
            short for one of the random step ranges to be non-empty.
        IndexError: If a random choice has nothing to choose from (a node without
            services, or no possible IER destination at all).
        Exception: If the network has no links.
    """
    # Reset the current red iers and red node pol
    self.red_iers = {}
    self.red_node_pol = {}

    node_list = list(self.nodes.values())

    # Only computers can become compromised
    computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
    if not computers:
        # Without this guard randint(1, 0) below fails with an unhelpful "empty range" error
        msg = "Random red agent cannot be used on a network without any computers"
        _LOGGER.error(msg)
        raise ValueError(msg)

    # Decide how many nodes become compromised (between 1 and all computers), then which ones
    num_nodes_to_compromise = randint(1, len(computers))
    nodes_to_be_compromised = sample(computers, num_nodes_to_compromise)

    # choose a random compromise node to be source of attacks
    source_node = choice(nodes_to_be_compromised)

    # For each of the nodes to be compromised decide which step they become compromised;
    # always compromise in first half of episode
    max_step_compromised = self.episode_steps // 2

    # Bandwidth for all links
    bandwidths = [link.get_bandwidth() for link in self.links.values()]
    if not bandwidths:
        msg = "Random red agent cannot be used on a network without any links"
        _LOGGER.error(msg)
        raise Exception(msg)

    servers = [node for node in node_list if node.node_type == NodeType.SERVER]

    for node in nodes_to_be_compromised:
        # 1: Use Node PoL to set node to compromised
        _id = str(uuid.uuid4())
        _start_step = randint(2, max_step_compromised + 1)  # step compromised
        pol_service_name = choice(list(node.services.keys()))

        source_node_service = choice(list(source_node.services.values()))

        red_pol = NodeStateInstructionRed(
            _id=_id,
            _start_step=_start_step,
            _end_step=_start_step,  # only run for 1 step
            _target_node_id=node.node_id,
            _pol_initiator="DIRECT",
            _pol_type=NodePOLType["SERVICE"],
            pol_protocol=pol_service_name,
            _pol_state=SoftwareState.COMPROMISED,
            _pol_source_node_id=source_node.node_id,
            _pol_source_node_service=source_node_service.name,
            _pol_source_node_service_state=source_node_service.software_state,
        )

        self.red_node_pol[_id] = red_pol

        # 2: Launch the attack from compromised node - set the IER
        ier_id = str(uuid.uuid4())
        # Launch the attack after node is compromised, and not right at the end of the episode.
        # NOTE(review): for short episodes the lower bound can exceed the upper bound, in which
        # case randint raises ValueError.
        ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
        ier_end_step = self.episode_steps

        # Randomise the load, as a percentage of a random link bandwidth
        ier_load = uniform(0.4, 0.8) * choice(bandwidths)
        ier_protocol = pol_service_name  # Same protocol as compromised node
        ier_port = node.services[pol_service_name].port
        ier_mission_criticality = 0  # Red IER will never be important to green agent success

        # We choose a node to attack based on the first that applies:
        # a. Green IERs, select dest node of the red ier based on dest node of green IER
        # b. Attack a random server that doesn't have a DENY acl rule in default config
        # c. Attack a random server
        possible_ier_destinations = [
            ier.get_dest_node_id()
            for ier in self.green_iers.values()
            if ier.get_source_node_id() == node.node_id
        ]
        if not possible_ier_destinations:
            # The ACL is queried with the protocol *name*, the same kind of value ACL rules
            # are created with from config and that the IER below is built with. The
            # Service object was previously passed here.
            possible_ier_destinations = [
                server.node_id
                for server in servers
                if not self.acl.is_blocked(
                    node.ip_address,
                    server.ip_address,
                    ier_protocol,
                    ier_port,
                )
            ]
        if not possible_ier_destinations:
            # If still none found choose from all servers
            possible_ier_destinations = [server.node_id for server in servers]
        ier_dest = choice(possible_ier_destinations)

        self.red_iers[ier_id] = IER(
            ier_id,
            ier_start_step,
            ier_end_step,
            ier_load,
            ier_protocol,
            ier_port,
            node.node_id,
            ier_dest,
            ier_mission_criticality,
        )

        # NOTE(review): these two assignments modify the instruction already stored in
        # self.red_node_pol[_id] above (it used to be done through an alias named
        # ``overwhelm_pol``). Its id no longer matches its dictionary key and its end step
        # is no longer the "1 step" set at construction. Kept to preserve existing
        # behaviour - confirm whether this is intended.
        red_pol.id = str(uuid.uuid4())
        red_pol.end_step = self.episode_steps

        # 3: Make sure the targetted node can be set to overwhelmed - with node pol
        # # TODO remove duplicate red pol for same targetted service - must take into account start step
        o_pol_id = str(uuid.uuid4())
        o_red_pol = NodeStateInstructionRed(
            _id=o_pol_id,
            _start_step=ier_start_step,
            _end_step=self.episode_steps,
            _target_node_id=ier_dest,
            _pol_initiator="DIRECT",
            _pol_type=NodePOLType["SERVICE"],
            pol_protocol=ier_protocol,
            _pol_state=SoftwareState.OVERWHELMED,
            _pol_source_node_id=source_node.node_id,
            _pol_source_node_service=source_node_service.name,
            _pol_source_node_service_state=source_node_service.software_state,
        )
        self.red_node_pol[o_pol_id] = o_red_pol
|