Files
PrimAITE/src/primaite/environment/reward.py
2023-06-27 12:56:15 +01:00

381 lines
17 KiB
Python

# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# Module logger; calculate_reward_function emits debug messages about green IER penalties through it.
_LOGGER = getLogger(__name__)
def calculate_reward_function(
    initial_nodes,
    final_nodes,
    reference_nodes,
    green_iers,
    green_iers_reference,
    red_iers,
    step_count,
    config_values,
):
    """
    Calculate the reward for the current step.

    The reward is the sum of:

    * per-node scores comparing each final node with its reference node
      (hardware state, software state, service states, file system state);
    * ``config_values.red_ier_running`` for every red IER that is running
      during its active window;
    * ``config_values.green_ier_blocked`` weighted by mission criticality for
      every green IER that is NOT running in the live environment while it IS
      running in the reference environment, during its active window.

    An IER's active window is ``start_step <= step_count <= end_step``
    (inclusive at both ends).

    Args:
        initial_nodes: The nodes before red and blue agents take effect,
            keyed like ``final_nodes``.
        final_nodes: The nodes after red and blue agents take effect.
        reference_nodes: The nodes if there had been no red or blue effect,
            keyed like ``final_nodes``.
        green_iers: The live green IERs (these should be running).
        green_iers_reference: The green IERs of the reference environment,
            keyed like ``green_iers``.
        red_iers: The red IERs (ideally stopped by the blue agent).
        step_count: The current step.
        config_values: Config values providing the reward weights.

    Returns:
        The total reward for the step.
    """
    reward_value = 0

    # Node states: compare every final node with its reference counterpart.
    for node_key, final_node in final_nodes.items():
        initial_node = initial_nodes[node_key]
        reference_node = reference_nodes[node_key]

        # Hardware State
        reward_value += score_node_operating_state(
            final_node, initial_node, reference_node, config_values
        )

        # Software State
        if isinstance(final_node, (ActiveNode, ServiceNode)):
            reward_value += score_node_os_state(
                final_node, initial_node, reference_node, config_values
            )

        # Service State
        if isinstance(final_node, ServiceNode):
            reward_value += score_node_service_state(
                final_node, initial_node, reference_node, config_values
            )

        # File System State
        if isinstance(final_node, ActiveNode):
            reward_value += score_node_file_system(
                final_node, initial_node, reference_node, config_values
            )

    # Red IERs: penalise each one that is running during its active window.
    for red_ier in red_iers.values():
        start_step = red_ier.get_start_step()
        stop_step = red_ier.get_end_step()
        if start_step <= step_count <= stop_step and red_ier.get_is_running():
            reward_value += config_values.red_ier_running

    # Green IERs: penalise (weighted) each one that is blocked, but only if it
    # is supposed to be running, i.e. it is running in the reference env.
    for ier_key, ier_value in green_iers.items():
        reference_ier = green_iers_reference[ier_key]
        start_step = ier_value.get_start_step()
        stop_step = ier_value.get_end_step()
        if not (start_step <= step_count <= stop_step):
            continue

        # "Blocked" means NOT running; these flags must be the negation of
        # get_is_running(), otherwise the penalty logic below is inverted.
        reference_blocked = not reference_ier.get_is_running()
        live_blocked = not ier_value.get_is_running()
        ier_reward = (
            config_values.green_ier_blocked * ier_value.get_mission_criticality()
        )

        if live_blocked and not reference_blocked:
            _LOGGER.debug(
                f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
            )
            reward_value += ier_reward
        elif live_blocked and reference_blocked:
            _LOGGER.debug(
                (
                    f"IER {ier_key} is blocked in the reference and live environments. "
                    f"Penalty of {ier_reward} was NOT applied."
                )
            )
        elif not live_blocked and reference_blocked:
            _LOGGER.debug(
                (
                    f"IER {ier_key} is blocked in the reference env but not in the live one. "
                    f"Penalty of {ier_reward} was NOT applied."
                )
            )
    return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
    """
    Score the hardware (operating) state of a node against its reference.

    When the final hardware state equals the reference hardware state the
    score is ``config_values.all_ok``. Otherwise the score is the
    ``<final>_should_be_<reference>`` weight for that pair of states, or 0
    when no weight is defined for the pair.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values providing the reward weights

    Returns:
        The hardware-state score for the node.
    """
    score = 0
    actual_state = final_node.hardware_state
    expected_state = reference_node.hardware_state

    if actual_state == expected_state:
        # No different from the reference situation.
        score += config_values.all_ok
        return score

    # (reference state, final state) -> name of the config weight to apply.
    # NOTE(review): equal states are scored as all_ok above and never reach
    # this table, so config_values.resetting is never applied by this function.
    weight_names = {
        (HardwareState.ON, HardwareState.OFF): "off_should_be_on",
        (HardwareState.ON, HardwareState.RESETTING): "resetting_should_be_on",
        (HardwareState.OFF, HardwareState.ON): "on_should_be_off",
        (HardwareState.OFF, HardwareState.RESETTING): "resetting_should_be_off",
        (HardwareState.RESETTING, HardwareState.ON): "on_should_be_resetting",
        (HardwareState.RESETTING, HardwareState.OFF): "off_should_be_resetting",
    }
    weight_name = weight_names.get((expected_state, actual_state))
    if weight_name is not None:
        score += getattr(config_values, weight_name)
    return score
def score_node_os_state(final_node, initial_node, reference_node, config_values):
    """
    Score the software (OS) state of a node against its reference.

    A final software state equal to the reference one scores
    ``config_values.all_ok``. A differing state scores the
    ``<final>_should_be_<reference>`` weight for that pair, or 0 when the pair
    has no weight (e.g. a state other than GOOD, PATCHING or COMPROMISED).

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values providing the reward weights

    Returns:
        The software-state score for the node.
    """
    score = 0
    actual_state = final_node.software_state
    expected_state = reference_node.software_state

    if actual_state == expected_state:
        # No different from the reference situation.
        score += config_values.all_ok
        return score

    # (reference state, final state) -> name of the config weight to apply.
    # NOTE(review): equal states are scored as all_ok above, so the
    # config_values.patching / config_values.compromised weights are never
    # applied by this function.
    weight_names = {
        (SoftwareState.GOOD, SoftwareState.PATCHING): "patching_should_be_good",
        (SoftwareState.GOOD, SoftwareState.COMPROMISED): "compromised_should_be_good",
        (SoftwareState.PATCHING, SoftwareState.GOOD): "good_should_be_patching",
        (SoftwareState.PATCHING, SoftwareState.COMPROMISED): "compromised_should_be_patching",
        (SoftwareState.COMPROMISED, SoftwareState.GOOD): "good_should_be_compromised",
        (SoftwareState.COMPROMISED, SoftwareState.PATCHING): "patching_should_be_compromised",
    }
    weight_name = weight_names.get((expected_state, actual_state))
    if weight_name is not None:
        score += getattr(config_values, weight_name)
    return score
def score_node_service_state(final_node, initial_node, reference_node, config_values):
    """
    Calculates score relating to the service state(s) of a node.

    Every service on ``final_node`` is compared with the service of the same
    name on ``reference_node``. A service whose software state matches the
    reference scores ``config_values.all_ok``. A mismatch scores the
    ``<final>_should_be_<reference>`` weight for that pair of states, or 0 if
    no weight is defined for the pair. The per-service scores are summed.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values providing the reward weights

    Returns:
        The summed service-state score over all services of ``final_node``
        (0 when it has no services).

    Raises:
        KeyError: If ``reference_node`` lacks a service that ``final_node`` has.
    """
    # (reference state, final state) -> name of the config weight to apply.
    # NOTE(review): matching states are scored as all_ok before this table is
    # consulted, so the standalone patching / compromised / overwhelmed
    # weights are never applied here.
    mismatch_weights = {
        (SoftwareState.GOOD, SoftwareState.PATCHING): "patching_should_be_good",
        (SoftwareState.GOOD, SoftwareState.COMPROMISED): "compromised_should_be_good",
        (SoftwareState.GOOD, SoftwareState.OVERWHELMED): "overwhelmed_should_be_good",
        (SoftwareState.PATCHING, SoftwareState.GOOD): "good_should_be_patching",
        (SoftwareState.PATCHING, SoftwareState.COMPROMISED): "compromised_should_be_patching",
        (SoftwareState.PATCHING, SoftwareState.OVERWHELMED): "overwhelmed_should_be_patching",
        (SoftwareState.COMPROMISED, SoftwareState.GOOD): "good_should_be_compromised",
        (SoftwareState.COMPROMISED, SoftwareState.PATCHING): "patching_should_be_compromised",
        (SoftwareState.COMPROMISED, SoftwareState.OVERWHELMED): "overwhelmed_should_be_compromised",
        (SoftwareState.OVERWHELMED, SoftwareState.GOOD): "good_should_be_overwhelmed",
        (SoftwareState.OVERWHELMED, SoftwareState.PATCHING): "patching_should_be_overwhelmed",
        (SoftwareState.OVERWHELMED, SoftwareState.COMPROMISED): "compromised_should_be_overwhelmed",
    }

    score = 0
    final_node_services: Dict[str, Service] = final_node.services
    reference_node_services: Dict[str, Service] = reference_node.services

    for service_key, final_service in final_node_services.items():
        reference_service = reference_node_services[service_key]
        final_state = final_service.software_state
        reference_state = reference_service.software_state

        if final_state == reference_state:
            # All is well - we're no different from the reference situation
            score += config_values.all_ok
            continue

        weight_name = mismatch_weights.get((reference_state, final_state))
        if weight_name is not None:
            score += getattr(config_values, weight_name)
    return score
def score_node_file_system(final_node, initial_node, reference_node, config_values):
    """
    Score the file system of a node against its reference.

    Two independent components are added together:

    * File system state: ``config_values.all_ok`` when the final state equals
      the reference state; otherwise the ``<final>_should_be_<reference>``
      weight for that pair, or 0 when either state is not one of GOOD,
      REPAIRING, RESTORING, CORRUPT or DESTROYED.
    * Scanning state: ``config_values.all_ok`` when the final node's scanning
      flag equals the reference's, otherwise ``config_values.scanning``.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values providing the reward weights

    Returns:
        The combined file-system score for the node.
    """
    actual_fs_state = final_node.file_system_state_actual
    expected_fs_state = reference_node.file_system_state_actual
    actual_scanning = final_node.file_system_scanning
    expected_scanning = reference_node.file_system_scanning

    score = 0

    # File System State
    if actual_fs_state == expected_fs_state:
        # No different from the reference situation.
        score += config_values.all_ok
    else:
        # Weight names follow "<final>_should_be_<reference>".
        # NOTE(review): equal states are handled above, so the standalone
        # repairing / restoring / corrupt / destroyed weights are never applied.
        labels = {
            FileSystemState.GOOD: "good",
            FileSystemState.REPAIRING: "repairing",
            FileSystemState.RESTORING: "restoring",
            FileSystemState.CORRUPT: "corrupt",
            FileSystemState.DESTROYED: "destroyed",
        }
        actual_label = labels.get(actual_fs_state)
        expected_label = labels.get(expected_fs_state)
        if actual_label is not None and expected_label is not None:
            score += getattr(config_values, f"{actual_label}_should_be_{expected_label}")

    # Scanning State: any mismatch with the reference flag scores the scanning
    # weight (per the original intent, scanning is penalised as it slows systems).
    if actual_scanning == expected_scanning:
        score += config_values.all_ok
    else:
        score += config_values.scanning
    return score