diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 007dcdc8..a620f9b3 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -93,7 +93,6 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ """ score = 0 final_node_operating_state = final_node.hardware_state - initial_node_operating_state = initial_node.hardware_state reference_node_operating_state = reference_node.hardware_state if final_node_operating_state == reference_node_operating_state: @@ -101,27 +100,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference (current) state of node (i.e. at every step) - if initial_node_operating_state == HardwareState.ON: - if reference_node_operating_state == HardwareState.OFF: + # Need to compare reference and final (current) state of node (i.e. at every step) + if reference_node_operating_state == HardwareState.ON: + if final_node_operating_state == HardwareState.OFF: score += config_values.off_should_be_on - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting_should_be_on else: pass - elif initial_node_operating_state == HardwareState.OFF: - if reference_node_operating_state == HardwareState.ON: + elif reference_node_operating_state == HardwareState.OFF: + if final_node_operating_state == HardwareState.ON: score += config_values.on_should_be_off - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting_should_be_off else: pass - elif initial_node_operating_state == HardwareState.RESETTING: - if reference_node_operating_state == HardwareState.ON: + elif reference_node_operating_state == HardwareState.RESETTING: + if final_node_operating_state == HardwareState.ON: score += config_values.on_should_be_resetting - elif reference_node_operating_state == HardwareState.OFF: + elif final_node_operating_state == HardwareState.OFF: score += config_values.off_should_be_resetting - elif reference_node_operating_state == HardwareState.RESETTING: + elif final_node_operating_state == HardwareState.RESETTING: score += config_values.resetting else: pass @@ -143,7 +142,6 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) """ score = 0 final_node_os_state = final_node.software_state - initial_node_os_state = initial_node.software_state reference_node_os_state = reference_node.software_state if final_node_os_state == reference_node_os_state: @@ -151,29 +149,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference (current) state of node (i.e. at every step) - if initial_node_os_state == SoftwareState.GOOD: - if reference_node_os_state == SoftwareState.PATCHING: + # Need to compare reference and final (current) state of node (i.e. at every step) + if reference_node_os_state == SoftwareState.GOOD: + if final_node_os_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_good else: pass - elif initial_node_os_state == SoftwareState.PATCHING: - if reference_node_os_state == SoftwareState.GOOD: + elif reference_node_os_state == SoftwareState.PATCHING: + if final_node_os_state == SoftwareState.GOOD: score += config_values.good_should_be_patching - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_patching - elif reference_node_os_state == SoftwareState.PATCHING: + elif final_node_os_state == SoftwareState.PATCHING: score += config_values.patching else: pass - elif initial_node_os_state == SoftwareState.COMPROMISED: - if reference_node_os_state == SoftwareState.GOOD: + elif reference_node_os_state == SoftwareState.COMPROMISED: + if final_node_os_state == SoftwareState.GOOD: score += config_values.good_should_be_compromised - elif reference_node_os_state == SoftwareState.PATCHING: + elif final_node_os_state == SoftwareState.PATCHING: score += config_values.patching_should_be_compromised - elif reference_node_os_state == SoftwareState.COMPROMISED: + elif final_node_os_state == SoftwareState.COMPROMISED: score += config_values.compromised else: pass @@ -195,58 +193,57 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va """ score = 0 final_node_services: Dict[str, Service] = final_node.services - initial_node_services: Dict[str, Service] = initial_node.services reference_node_services: Dict[str, Service] = reference_node.services for service_key, final_service in final_node_services.items(): reference_service = reference_node_services[service_key] - initial_service = initial_node_services[service_key] + final_service = final_node_services[service_key] if final_service.software_state == reference_service.software_state: # All is well - we're no different from the reference situation score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference state of node (i.e. at every step) - if initial_service.software_state == SoftwareState.GOOD: - if reference_service.software_state == SoftwareState.PATCHING: + # Need to compare reference and final state of node (i.e. at every step) + if reference_service.software_state == SoftwareState.GOOD: + if final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_good - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_good - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_good else: pass - elif initial_service.software_state == SoftwareState.PATCHING: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.PATCHING: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_patching - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_patching - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_patching - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching else: pass - elif initial_service.software_state == SoftwareState.COMPROMISED: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.COMPROMISED: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_compromised - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_compromised - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed_should_be_compromised else: pass - elif initial_service.software_state == SoftwareState.OVERWHELMED: - if reference_service.software_state == SoftwareState.GOOD: + elif reference_service.software_state == SoftwareState.OVERWHELMED: + if final_service.software_state == SoftwareState.GOOD: score += config_values.good_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.PATCHING: + elif final_service.software_state == SoftwareState.PATCHING: score += config_values.patching_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.COMPROMISED: + elif final_service.software_state == SoftwareState.COMPROMISED: score += config_values.compromised_should_be_overwhelmed - elif reference_service.software_state == SoftwareState.OVERWHELMED: + elif final_service.software_state == SoftwareState.OVERWHELMED: score += config_values.overwhelmed else: pass @@ -267,7 +264,6 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu """ score = 0 final_node_file_system_state = final_node.file_system_state_actual - initial_node_file_system_state = initial_node.file_system_state_actual reference_node_file_system_state = reference_node.file_system_state_actual final_node_scanning_state = final_node.file_system_scanning @@ -279,67 +275,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu score += config_values.all_ok else: # We're different from the reference situation - # Need to compare initial and reference state of node (i.e. at every step) - if initial_node_file_system_state == FileSystemState.GOOD: - if reference_node_file_system_state == FileSystemState.REPAIRING: + # Need to compare reference and final state of node (i.e. at every step) + if reference_node_file_system_state == FileSystemState.GOOD: + if final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_good - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_good - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_good - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_good else: pass - elif initial_node_file_system_state == FileSystemState.REPAIRING: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.REPAIRING: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_repairing - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_repairing - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_repairing - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_repairing - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing else: pass - elif initial_node_file_system_state == FileSystemState.RESTORING: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.RESTORING: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_restoring - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_restoring - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_restoring - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_restoring - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring else: pass - elif initial_node_file_system_state == FileSystemState.CORRUPT: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.CORRUPT: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed_should_be_corrupt - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt else: pass - elif initial_node_file_system_state == FileSystemState.DESTROYED: - if reference_node_file_system_state == FileSystemState.GOOD: + elif reference_node_file_system_state == FileSystemState.DESTROYED: + if final_node_file_system_state == FileSystemState.GOOD: score += config_values.good_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.REPAIRING: + elif final_node_file_system_state == FileSystemState.REPAIRING: score += config_values.repairing_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.RESTORING: + elif final_node_file_system_state == FileSystemState.RESTORING: score += config_values.restoring_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.CORRUPT: + elif final_node_file_system_state == FileSystemState.CORRUPT: score += config_values.corrupt_should_be_destroyed - elif reference_node_file_system_state == FileSystemState.DESTROYED: + elif final_node_file_system_state == FileSystemState.DESTROYED: score += config_values.destroyed else: pass diff --git a/tests/config/one_node_states_on_off_lay_down_config.yaml b/tests/config/one_node_states_on_off_lay_down_config.yaml index 355760bf..00f8016e 100644 --- a/tests/config/one_node_states_on_off_lay_down_config.yaml +++ b/tests/config/one_node_states_on_off_lay_down_config.yaml @@ -1,7 +1,7 @@ - itemType: ACTIONS type: NODE - itemType: STEPS - steps: 13 + steps: 15 - itemType: PORTS portsList: - port: '21' @@ -42,7 +42,7 @@ - itemType: RED_POL id: '2' startStep: 3 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: FILE @@ -66,7 +66,7 @@ - itemType: RED_POL id: '4' startStep: 6 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: OPERATING @@ -90,7 +90,7 @@ - itemType: RED_POL id: '6' startStep: 9 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: SERVICE @@ -114,7 +114,7 @@ - itemType: RED_POL id: '8' startStep: 12 - endStep: 13 + endStep: 15 targetNodeId: '1' initiator: DIRECT type: OS diff --git a/tests/test_reward.py b/tests/test_reward.py index 10dfb79c..4925a434 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -28,4 +28,4 @@ def test_rewards_are_being_penalised_at_each_step_function(): Average Reward: 2 (26 / 13) """ print("average reward", env.average_reward) - assert env.average_reward == 2.0 + assert env.average_reward == -8.0