Merged PR 56: #902 - Fix the reward functionality for node operating system state

#902 - replaced 'final_node_<placeholder>' with 'reference_node_<placeholder>' in methods for scoring of os_state, file_system_state, service state and operating state. This fixed the reward function so it is checked at each step for node operating system state, operating state, file system state and service state.
- Added unit tests.

Related work items: #902
This commit is contained in:
Christopher McCarthy
2023-05-25 15:28:19 +00:00
9 changed files with 526 additions and 75 deletions

View File

@@ -37,6 +37,8 @@ from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
@@ -145,14 +147,12 @@ class Primaite(Env):
# Open the config file and build the environment laydown
try:
self.config_file = open(
"config/" + self.config_values.config_filename_use_case, "r"
)
self.config_file = open(self.config_values.config_filename_use_case, "r")
self.config_data = yaml.safe_load(self.config_file)
self.load_config()
except Exception:
logging.error("Could not load the environment configuration")
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Could not load the environment configuration")
_LOGGER.error("Exception occured", exc_info=True)
# Store the node objects as node attributes
# (This is so we can access them as objects)
@@ -180,8 +180,8 @@ class Primaite(Env):
plt.savefig(filename, format="PNG")
plt.clf()
except Exception:
logging.error("Could not save network diagram")
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Could not save network diagram")
_LOGGER.error("Exception occured", exc_info=True)
print("Could not save network diagram")
# Define Observation Space
@@ -223,7 +223,7 @@ class Primaite(Env):
# Define Action Space - depends on action space type (Node or ACL)
if self.action_type == ACTION_TYPE.NODE:
logging.info("Action space type NODE selected")
_LOGGER.info("Action space type NODE selected")
# Terms (for node action space):
# [0, num nodes] - node ID (0 = nothing, node ID)
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
@@ -238,7 +238,7 @@ class Primaite(Env):
]
)
else:
logging.info("Action space type ACL selected")
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
@@ -273,10 +273,10 @@ class Primaite(Env):
self.csv_writer = csv.writer(self.csv_file)
self.csv_writer.writerow(header)
except Exception:
logging.error(
_LOGGER.error(
"Could not create csv file to hold average reward per episode"
)
logging.error("Exception occured", exc_info=True)
_LOGGER.error("Exception occured", exc_info=True)
def reset(self):
"""
@@ -322,7 +322,7 @@ class Primaite(Env):
step_info: Additional information relating to this step
"""
if self.step_count == 0:
print("Episode: " + str(self.episode_count) + " running")
print(f"Episode: {str(self.episode_count)}")
# TEMP
done = False
@@ -402,7 +402,7 @@ class Primaite(Env):
self.step_count,
self.config_values,
)
# print("Step reward: " + str(reward))
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -410,7 +410,7 @@ class Primaite(Env):
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
print("Average reward: " + str(self.average_reward))
print(f" Average Reward: {str(self.average_reward)}")
# Load the reward into the transaction
transaction.set_reward(reward)
@@ -757,7 +757,7 @@ class Primaite(Env):
# Do nothing (bad formatting)
pass
logging.info("Environment configuration loaded")
_LOGGER.info("Environment configuration loaded")
print("Environment configuration loaded")
def create_node(self, item):
@@ -1090,7 +1090,7 @@ class Primaite(Env):
item: A config data item representing steps info
"""
self.episode_steps = int(steps_info["steps"])
logging.info("Training episodes have " + str(self.episode_steps) + " steps")
_LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps")
def reset_environment(self):
"""

View File

@@ -98,27 +98,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions)
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_operating_state == HARDWARE_STATE.ON:
if final_node_operating_state == HARDWARE_STATE.OFF:
if reference_node_operating_state == HARDWARE_STATE.OFF:
score += config_values.off_should_be_on
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
score += config_values.resetting_should_be_on
else:
pass
elif initial_node_operating_state == HARDWARE_STATE.OFF:
if final_node_operating_state == HARDWARE_STATE.ON:
if reference_node_operating_state == HARDWARE_STATE.ON:
score += config_values.on_should_be_off
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
score += config_values.resetting_should_be_off
else:
pass
elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
if final_node_operating_state == HARDWARE_STATE.ON:
if reference_node_operating_state == HARDWARE_STATE.ON:
score += config_values.on_should_be_resetting
elif final_node_operating_state == HARDWARE_STATE.OFF:
elif reference_node_operating_state == HARDWARE_STATE.OFF:
score += config_values.off_should_be_resetting
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
score += config_values.resetting
else:
pass
@@ -148,29 +148,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions)
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_os_state == SOFTWARE_STATE.GOOD:
if final_node_os_state == SOFTWARE_STATE.PATCHING:
if reference_node_os_state == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_good
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_good
else:
pass
elif initial_node_os_state == SOFTWARE_STATE.PATCHING:
if final_node_os_state == SOFTWARE_STATE.GOOD:
if reference_node_os_state == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_patching
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
score += config_values.patching
else:
pass
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
if final_node_os_state == SOFTWARE_STATE.GOOD:
if reference_node_os_state == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_compromised
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_compromised
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised
else:
pass
@@ -204,46 +204,46 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions)
# Need to compare initial and reference state of node (i.e. at every step)
if initial_service.get_state() == SOFTWARE_STATE.GOOD:
if final_service.get_state() == SOFTWARE_STATE.PATCHING:
if reference_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_good
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_good
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed_should_be_good
else:
pass
elif initial_service.get_state() == SOFTWARE_STATE.PATCHING:
if final_service.get_state() == SOFTWARE_STATE.GOOD:
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching
else:
pass
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
if final_service.get_state() == SOFTWARE_STATE.GOOD:
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_compromised
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_compromised
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised
else:
pass
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
if final_service.get_state() == SOFTWARE_STATE.GOOD:
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
score += config_values.good_should_be_overwhelmed
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
score += config_values.patching_should_be_overwhelmed
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
score += config_values.overwhelmed
else:
pass
@@ -276,67 +276,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and final state of node (i.e. after red and blue actions)
# Need to compare initial and reference state of node (i.e. at every step)
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
if final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
if reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing_should_be_good
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_good
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_good
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_good
else:
pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_repairing
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing
else:
pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_restoring
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring
else:
pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed_should_be_corrupt
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt
else:
pass
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
score += config_values.good_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
score += config_values.repairing_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
score += config_values.restoring_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
score += config_values.corrupt_should_be_destroyed
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
score += config_values.destroyed
else:
pass

View File

@@ -165,7 +165,9 @@ def load_config_values():
config_values.agent_identifier = config_data["agentIdentifier"]
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = config_data["configFilename"]
config_values.config_filename_use_case = (
"config/" + config_data["configFilename"]
)
config_values.session_type = config_data["sessionType"]
config_values.load_agent = bool(config_data["loadAgent"])
config_values.agent_load_file = config_data["agentLoadFile"]

View File

@@ -0,0 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from pathlib import Path
from typing import Final
# Path to the "config" directory located next to this package's __init__.py.
# The reward test builds its main-config and lay-down-config YAML paths from it.
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
"The tests config root directory."

0
tests/config/__init__.py Normal file
View File

View File

@@ -0,0 +1,125 @@
- itemType: ACTIONS
type: NODE
- itemType: STEPS
steps: 13
- itemType: PORTS
portsList:
- port: '21'
- itemType: SERVICES
serviceList:
- name: ftp
- itemType: NODE
id: '1'
name: node
baseType: SERVICE
nodeType: COMPUTER
priority: P1
hardwareState: 'ON'
ipAddress: 192.168.0.1
softwareState: GOOD
fileSystemState: GOOD
services:
- name: ftp
port: '21'
state: GOOD
- itemType: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- itemType: RED_POL
id: '1'
startStep: 1
endStep: 3
targetNodeId: '1'
initiator: DIRECT
type: FILE
protocol: NA
state: CORRUPT
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '2'
startStep: 3
endStep: 13
targetNodeId: '1'
initiator: DIRECT
type: FILE
protocol: NA
state: GOOD
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '3'
startStep: 4
endStep: 6
targetNodeId: '1'
initiator: DIRECT
type: OPERATING
protocol: NA
state: 'OFF'
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '4'
startStep: 6
endStep: 13
targetNodeId: '1'
initiator: DIRECT
type: OPERATING
protocol: NA
state: 'ON'
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '5'
startStep: 7
endStep: 9
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: ftp
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '6'
startStep: 9
endStep: 13
targetNodeId: '1'
initiator: DIRECT
type: SERVICE
protocol: ftp
state: GOOD
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '7'
startStep: 10
endStep: 12
targetNodeId: '1'
initiator: DIRECT
type: OS
protocol: NA
state: COMPROMISED
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA
- itemType: RED_POL
id: '8'
startStep: 12
endStep: 13
targetNodeId: '1'
initiator: DIRECT
type: OS
protocol: NA
state: GOOD
sourceNodeId: NA
sourceNodeService: NA
sourceNodeServiceState: NA

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: GENERIC
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Operating State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node O/S or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time taken to repair the file system
fileSystemRestoringLimit: 5 # The time taken to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -1 +1,199 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
import time
from pathlib import Path
from typing import Union
import yaml
from primaite.common.config_values_main import config_values_main
from primaite.environment.primaite_env import Primaite
# NOTE(review): neither constant is referenced by the helpers visible in this
# module; presumably they describe the NODE action space for tests - confirm
# their intended use (or remove them) with the test author.
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
# Attributes of the config object that are read from the main config YAML as
# integers. The YAML key for each one is the camelCase form of the attribute
# name (e.g. ``off_should_be_on`` is read from ``offShouldBeOn``).
_INT_CONFIG_ATTRIBUTES = (
    # Environment
    "observation_space_high_value",
    # Reward values - generic
    "all_ok",
    # Reward values - node operating state
    "off_should_be_on",
    "off_should_be_resetting",
    "on_should_be_off",
    "on_should_be_resetting",
    "resetting_should_be_on",
    "resetting_should_be_off",
    "resetting",
    # Reward values - node O/S or service state
    "good_should_be_patching",
    "good_should_be_compromised",
    "good_should_be_overwhelmed",
    "patching_should_be_good",
    "patching_should_be_compromised",
    "patching_should_be_overwhelmed",
    "patching",
    "compromised_should_be_good",
    "compromised_should_be_patching",
    "compromised_should_be_overwhelmed",
    "compromised",
    "overwhelmed_should_be_good",
    "overwhelmed_should_be_patching",
    "overwhelmed_should_be_compromised",
    "overwhelmed",
    # Reward values - node file system state
    "good_should_be_repairing",
    "good_should_be_restoring",
    "good_should_be_corrupt",
    "good_should_be_destroyed",
    "repairing_should_be_good",
    "repairing_should_be_restoring",
    "repairing_should_be_corrupt",
    "repairing_should_be_destroyed",
    "repairing",
    "restoring_should_be_good",
    "restoring_should_be_repairing",
    "restoring_should_be_corrupt",
    "restoring_should_be_destroyed",
    "restoring",
    "corrupt_should_be_good",
    "corrupt_should_be_repairing",
    "corrupt_should_be_restoring",
    "corrupt_should_be_destroyed",
    "corrupt",
    "destroyed_should_be_good",
    "destroyed_should_be_repairing",
    "destroyed_should_be_restoring",
    "destroyed_should_be_corrupt",
    "destroyed",
    "scanning",
    # IER status
    "red_ier_running",
    "green_ier_blocked",
    # Patching / reset durations
    "os_patching_duration",
    "node_reset_duration",
    "service_patching_duration",
    "file_system_repairing_limit",
    "file_system_restoring_limit",
    "file_system_scanning_limit",
)


def _snake_to_camel(name: str) -> str:
    """Return the camelCase YAML key for a snake_case attribute name."""
    head, *tail = name.split("_")
    return head + "".join(part.capitalize() for part in tail)


def _get_primaite_env_from_config(
    main_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
):
    """
    Build a Primaite environment from a main config and a lay down config.

    The main config YAML is parsed and copied onto a fresh
    ``config_values_main`` object, the environment is constructed from it and,
    when the configured agent identifier is ``"GENERIC"``, the session is run
    via ``run_generic`` before the environment is returned.

    :param main_config_path: Path of the main config YAML file to read.
    :param lay_down_config_path: Stored unchanged as
        ``config_values.config_filename_use_case`` (the value from the
        ``configFilename`` key of the main config is not used here).
    :return: The created ``Primaite`` instance.
    :raises KeyError: If an expected key is missing from the main config.
    """
    # Use a context manager so the file handle is closed as soon as the YAML
    # has been parsed (it was previously left open).
    with open(main_config_path, "r") as config_file_main:
        config_data = yaml.safe_load(config_file_main)

    config_values = config_values_main()

    # Session settings
    config_values.agent_identifier = config_data["agentIdentifier"]
    config_values.num_episodes = int(config_data["numEpisodes"])
    config_values.time_delay = int(config_data["timeDelay"])
    config_values.config_filename_use_case = lay_down_config_path
    config_values.session_type = config_data["sessionType"]
    config_values.load_agent = bool(config_data["loadAgent"])
    config_values.agent_load_file = config_data["agentLoadFile"]

    # Environment, reward and duration settings are all plain integers
    for attribute in _INT_CONFIG_ATTRIBUTES:
        setattr(
            config_values, attribute, int(config_data[_snake_to_camel(attribute)])
        )

    env = Primaite(config_values, [])
    # The number of steps per episode comes from the environment, which reads
    # it from the lay down config rather than the main config.
    config_values.num_steps = env.episode_steps

    if env.config_values.agent_identifier == "GENERIC":
        run_generic(env, config_values)

    return env
def run_generic(env, config_values):
    """
    Step the environment with randomly sampled actions.

    For each of ``config_values.num_episodes`` episodes, up to
    ``config_values.num_steps`` steps are taken. An episode's step loop ends
    early as soon as ``env.step`` reports ``done``. The environment is not
    reset or closed by this function.

    :param env: Object exposing ``action_space.sample()`` and ``step(action)``,
        where ``step`` returns a 4-tuple whose third item is the done flag.
    :param config_values: Object providing ``num_episodes``, ``num_steps`` and
        ``time_delay`` (divided by 1000 before being passed to ``time.sleep``).
    """
    for _episode in range(config_values.num_episodes):
        for _step in range(config_values.num_steps):
            # TEMP - a random action stands in for a real agent's decision
            sampled_action = env.action_space.sample()

            # Only the done flag of the step result is used here
            _obs, _reward, episode_done, _info = env.step(sampled_action)
            if episode_done:
                break

            # Pause between steps (skipped after the final step of an episode)
            time.sleep(config_values.time_delay / 1000)

View File

@@ -0,0 +1,31 @@
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def test_rewards_are_being_penalised_at_each_step_function():
    """
    Check that node state rewards are applied at every step of an episode.

    Runs the one-node ON/OFF scenario and verifies the average reward of the
    episode, which includes the hardware state penalty for steps where the
    initial state is OFF while the reference state is ON.
    """
    main_config = TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml"
    lay_down_config = TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml"

    env = _get_primaite_env_from_config(
        main_config_path=main_config,
        lay_down_config_path=lay_down_config,
    )

    # Rewards expected to be activated across the 13 steps of the episode:
    #   File System State: goodShouldBeCorrupt = 5 (between steps 1 & 3)
    #   Hardware State: onShouldBeOff = -2 (between steps 4 & 6)
    #   Service State: goodShouldBeCompromised = 5 (between steps 7 & 9)
    #   Operating System State (Software State):
    #       goodShouldBeCompromised = 5 (between steps 10 & 12)
    #
    #   Total reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
    #   Step count: 13
    #   Average reward: 26 / 13 = 2
    average_reward = env.average_reward
    print("average reward", average_reward)
    assert average_reward == 2.0