Files
PrimAITE/src/primaite/environment/reward.py
Chris McCarthy 32a4d9e459 #1355 - Carried out full renaming in node.py, active_node.py, passive_node.py, and service_node.py to make params and variable names explicit.
- Made the same renaming in the yaml laydown config files.
- Added Type hints wherever I've been.
- Added a custom NodeType in custom_typing.py to encompass the Union of ActiveNode, PassiveNode, ServiceNode.
2023-05-25 21:03:11 +01:00

359 lines
17 KiB
Python

# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Implements reward function."""
from typing import Dict
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
red_iers,
step_count,
config_values,
):
"""
Compares the states of the initial and final nodes/links to get a reward.
Args:
initial_nodes: The nodes before red and blue agents take effect
final_nodes: The nodes after red and blue agents take effect
reference_nodes: The nodes if there had been no red or blue effect
green_iers: The green IERs (should be running)
red_iers: Should be stopeed (ideally) by the blue agent
step_count: current step
config_values: Config values
"""
reward_value = 0
# For each node, compare hardware state, SoftwareState, service states
for node_key, final_node in final_nodes.items():
initial_node = initial_nodes[node_key]
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(
final_node, initial_node, reference_node, config_values
)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(
final_node, initial_node, reference_node, config_values
)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(
final_node, initial_node, reference_node, config_values
)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if ier_value.get_is_running():
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
for ier_key, ier_value in green_iers.items():
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running():
reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
)
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the hardware state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
final_node_operating_state = final_node.hardware_state
initial_node_operating_state = initial_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
if final_node_operating_state == reference_node_operating_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_operating_state == HardwareState.ON:
if reference_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_on
elif reference_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_on
else:
pass
elif initial_node_operating_state == HardwareState.OFF:
if reference_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_off
elif reference_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_off
else:
pass
elif initial_node_operating_state == HardwareState.RESETTING:
if reference_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_resetting
elif reference_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_resetting
elif reference_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting
else:
pass
else:
pass
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the Software State of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
final_node_os_state = final_node.software_state
initial_node_os_state = initial_node.software_state
reference_node_os_state = reference_node.software_state
if final_node_os_state == reference_node_os_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_os_state == SoftwareState.GOOD:
if reference_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif reference_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
else:
pass
elif initial_node_os_state == SoftwareState.PATCHING:
if reference_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif reference_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif reference_node_os_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif initial_node_os_state == SoftwareState.COMPROMISED:
if reference_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif reference_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif reference_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised
else:
pass
else:
pass
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the service state(s) of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
final_node_services: Dict[str, Service] = final_node.services
initial_node_services: Dict[str, Service] = initial_node.services
reference_node_services: Dict[str, Service] = reference_node.services
for service_key, final_service in final_node_services.items():
reference_service = reference_node_services[service_key]
initial_service = initial_node_services[service_key]
if final_service.software_state == reference_service.software_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference state of node (i.e. at every step)
if initial_service.software_state == SoftwareState.GOOD:
if reference_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif reference_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
elif reference_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_good
else:
pass
elif initial_service.software_state == SoftwareState.PATCHING:
if reference_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif reference_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif reference_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching
elif reference_service.software_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif initial_service.software_state == SoftwareState.COMPROMISED:
if reference_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif reference_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif reference_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised
elif reference_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised
else:
pass
elif initial_service.software_state == SoftwareState.OVERWHELMED:
if reference_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed
else:
pass
else:
pass
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the file system state of a node.
Args:
final_node: The node after red and blue agents take effect
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
"""
score = 0
final_node_file_system_state = final_node.file_system_state_actual
initial_node_file_system_state = initial_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual
final_node_scanning_state = final_node.file_system_scanning
reference_node_scanning_state = reference_node.file_system_scanning
# File System State
if final_node_file_system_state == reference_node_file_system_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference state of node (i.e. at every step)
if initial_node_file_system_state == FileSystemState.GOOD:
if reference_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_good
elif reference_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_good
elif reference_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_good
elif reference_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_good
else:
pass
elif initial_node_file_system_state == FileSystemState.REPAIRING:
if reference_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_repairing
elif reference_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_repairing
elif reference_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_repairing
elif reference_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_repairing
elif reference_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing
else:
pass
elif initial_node_file_system_state == FileSystemState.RESTORING:
if reference_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_restoring
elif reference_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_restoring
elif reference_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_restoring
elif reference_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_restoring
elif reference_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring
else:
pass
elif initial_node_file_system_state == FileSystemState.CORRUPT:
if reference_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt
else:
pass
elif initial_node_file_system_state == FileSystemState.DESTROYED:
if reference_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed
else:
pass
else:
pass
# Scanning State
if final_node_scanning_state == reference_node_scanning_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# We're scanning the file system which incurs a penalty (as it slows down systems)
score += config_values.scanning
return score