#902 - replaced 'final_node_<placeholder>' with 'reference_node_<placeholder>' in methods for scoring of os_state, file_system_state, service state and operating state. This fixed the reward function so it is checked at each step for node operating system state, operating state, file system state and service state.
- Added unit tests.
This commit is contained in:
@@ -37,6 +37,8 @@ from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||
@@ -145,14 +147,12 @@ class Primaite(Env):
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
try:
|
||||
self.config_file = open(
|
||||
"config/" + self.config_values.config_filename_use_case, "r"
|
||||
)
|
||||
self.config_file = open(self.config_values.config_filename_use_case, "r")
|
||||
self.config_data = yaml.safe_load(self.config_file)
|
||||
self.load_config()
|
||||
except Exception:
|
||||
logging.error("Could not load the environment configuration")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not load the environment configuration")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
|
||||
# Store the node objects as node attributes
|
||||
# (This is so we can access them as objects)
|
||||
@@ -180,8 +180,8 @@ class Primaite(Env):
|
||||
plt.savefig(filename, format="PNG")
|
||||
plt.clf()
|
||||
except Exception:
|
||||
logging.error("Could not save network diagram")
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Could not save network diagram")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
print("Could not save network diagram")
|
||||
|
||||
# Define Observation Space
|
||||
@@ -223,7 +223,7 @@ class Primaite(Env):
|
||||
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
if self.action_type == ACTION_TYPE.NODE:
|
||||
logging.info("Action space type NODE selected")
|
||||
_LOGGER.info("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, o/s state, service state, file system state)
|
||||
@@ -238,7 +238,7 @@ class Primaite(Env):
|
||||
]
|
||||
)
|
||||
else:
|
||||
logging.info("Action space type ACL selected")
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
@@ -273,10 +273,10 @@ class Primaite(Env):
|
||||
self.csv_writer = csv.writer(self.csv_file)
|
||||
self.csv_writer.writerow(header)
|
||||
except Exception:
|
||||
logging.error(
|
||||
_LOGGER.error(
|
||||
"Could not create csv file to hold average reward per episode"
|
||||
)
|
||||
logging.error("Exception occured", exc_info=True)
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -322,7 +322,7 @@ class Primaite(Env):
|
||||
step_info: Additional information relating to this step
|
||||
"""
|
||||
if self.step_count == 0:
|
||||
print("Episode: " + str(self.episode_count) + " running")
|
||||
print(f"Episode: {str(self.episode_count)}")
|
||||
|
||||
# TEMP
|
||||
done = False
|
||||
@@ -402,7 +402,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.config_values,
|
||||
)
|
||||
# print("Step reward: " + str(reward))
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -410,7 +410,7 @@ class Primaite(Env):
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
print("Average reward: " + str(self.average_reward))
|
||||
print(f" Average Reward: {str(self.average_reward)}")
|
||||
# Load the reward into the transaction
|
||||
transaction.set_reward(reward)
|
||||
|
||||
@@ -757,7 +757,7 @@ class Primaite(Env):
|
||||
# Do nothing (bad formatting)
|
||||
pass
|
||||
|
||||
logging.info("Environment configuration loaded")
|
||||
_LOGGER.info("Environment configuration loaded")
|
||||
print("Environment configuration loaded")
|
||||
|
||||
def create_node(self, item):
|
||||
@@ -1090,7 +1090,7 @@ class Primaite(Env):
|
||||
item: A config data item representing steps info
|
||||
"""
|
||||
self.episode_steps = int(steps_info["steps"])
|
||||
logging.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||
_LOGGER.info("Training episodes have " + str(self.episode_steps) + " steps")
|
||||
|
||||
def reset_environment(self):
|
||||
"""
|
||||
|
||||
@@ -98,27 +98,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
# Need to compare initial and reference (current) state of node (i.e. at every step)
|
||||
if initial_node_operating_state == HARDWARE_STATE.ON:
|
||||
if final_node_operating_state == HARDWARE_STATE.OFF:
|
||||
if reference_node_operating_state == HARDWARE_STATE.OFF:
|
||||
score += config_values.off_should_be_on
|
||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
score += config_values.resetting_should_be_on
|
||||
else:
|
||||
pass
|
||||
elif initial_node_operating_state == HARDWARE_STATE.OFF:
|
||||
if final_node_operating_state == HARDWARE_STATE.ON:
|
||||
if reference_node_operating_state == HARDWARE_STATE.ON:
|
||||
score += config_values.on_should_be_off
|
||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
score += config_values.resetting_should_be_off
|
||||
else:
|
||||
pass
|
||||
elif initial_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
if final_node_operating_state == HARDWARE_STATE.ON:
|
||||
if reference_node_operating_state == HARDWARE_STATE.ON:
|
||||
score += config_values.on_should_be_resetting
|
||||
elif final_node_operating_state == HARDWARE_STATE.OFF:
|
||||
elif reference_node_operating_state == HARDWARE_STATE.OFF:
|
||||
score += config_values.off_should_be_resetting
|
||||
elif final_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
elif reference_node_operating_state == HARDWARE_STATE.RESETTING:
|
||||
score += config_values.resetting
|
||||
else:
|
||||
pass
|
||||
@@ -148,29 +148,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
# Need to compare initial and reference (current) state of node (i.e. at every step)
|
||||
if initial_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
if final_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
if reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif initial_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
if reference_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif initial_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
if final_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
if reference_node_os_state == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
elif reference_node_os_state == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_node_os_state == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
else:
|
||||
pass
|
||||
@@ -204,46 +204,46 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
# Need to compare initial and reference state of node (i.e. at every step)
|
||||
if initial_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
if final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
if reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif initial_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_patching
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif initial_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_compromised
|
||||
else:
|
||||
pass
|
||||
elif initial_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
if final_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
if reference_service.get_state() == SOFTWARE_STATE.GOOD:
|
||||
score += config_values.good_should_be_overwhelmed
|
||||
elif final_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.PATCHING:
|
||||
score += config_values.patching_should_be_overwhelmed
|
||||
elif final_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.COMPROMISED:
|
||||
score += config_values.compromised_should_be_overwhelmed
|
||||
elif final_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
elif reference_service.get_state() == SOFTWARE_STATE.OVERWHELMED:
|
||||
score += config_values.overwhelmed
|
||||
else:
|
||||
pass
|
||||
@@ -276,67 +276,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare initial and final state of node (i.e. after red and blue actions)
|
||||
# Need to compare initial and reference state of node (i.e. at every step)
|
||||
if initial_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
if reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing_should_be_good
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_good
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_good
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_repairing
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_restoring
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed_should_be_corrupt
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt
|
||||
else:
|
||||
pass
|
||||
elif initial_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
if final_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
if reference_node_file_system_state == FILE_SYSTEM_STATE.GOOD:
|
||||
score += config_values.good_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.REPAIRING:
|
||||
score += config_values.repairing_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.RESTORING:
|
||||
score += config_values.restoring_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.CORRUPT:
|
||||
score += config_values.corrupt_should_be_destroyed
|
||||
elif final_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
elif reference_node_file_system_state == FILE_SYSTEM_STATE.DESTROYED:
|
||||
score += config_values.destroyed
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -165,7 +165,9 @@ def load_config_values():
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = config_data["configFilename"]
|
||||
config_values.config_filename_use_case = (
|
||||
"config/" + config_data["configFilename"]
|
||||
)
|
||||
config_values.session_type = config_data["sessionType"]
|
||||
config_values.load_agent = bool(config_data["loadAgent"])
|
||||
config_values.agent_load_file = config_data["agentLoadFile"]
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
TEST_CONFIG_ROOT: Final[Path] = Path(__file__).parent / "config"
|
||||
"The tests config root directory."
|
||||
|
||||
0
tests/config/__init__.py
Normal file
0
tests/config/__init__.py
Normal file
125
tests/config/one_node_states_on_off_lay_down_config.yaml
Normal file
125
tests/config/one_node_states_on_off_lay_down_config.yaml
Normal file
@@ -0,0 +1,125 @@
|
||||
- itemType: ACTIONS
|
||||
type: NODE
|
||||
- itemType: STEPS
|
||||
steps: 13
|
||||
- itemType: PORTS
|
||||
portsList:
|
||||
- port: '21'
|
||||
- itemType: SERVICES
|
||||
serviceList:
|
||||
- name: ftp
|
||||
- itemType: NODE
|
||||
id: '1'
|
||||
name: node
|
||||
baseType: SERVICE
|
||||
nodeType: COMPUTER
|
||||
priority: P1
|
||||
hardwareState: 'ON'
|
||||
ipAddress: 192.168.0.1
|
||||
softwareState: GOOD
|
||||
fileSystemState: GOOD
|
||||
services:
|
||||
- name: ftp
|
||||
port: '21'
|
||||
state: GOOD
|
||||
- itemType: POSITION
|
||||
positions:
|
||||
- node: '1'
|
||||
x_pos: 309
|
||||
y_pos: 78
|
||||
- itemType: RED_POL
|
||||
id: '1'
|
||||
startStep: 1
|
||||
endStep: 3
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: CORRUPT
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '2'
|
||||
startStep: 3
|
||||
endStep: 13
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: GOOD
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '3'
|
||||
startStep: 4
|
||||
endStep: 6
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: OPERATING
|
||||
protocol: NA
|
||||
state: 'OFF'
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '4'
|
||||
startStep: 6
|
||||
endStep: 13
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: OPERATING
|
||||
protocol: NA
|
||||
state: 'ON'
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '5'
|
||||
startStep: 7
|
||||
endStep: 9
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: ftp
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '6'
|
||||
startStep: 9
|
||||
endStep: 13
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: ftp
|
||||
state: GOOD
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '7'
|
||||
startStep: 10
|
||||
endStep: 12
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: OS
|
||||
protocol: NA
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- itemType: RED_POL
|
||||
id: '8'
|
||||
startStep: 12
|
||||
endStep: 13
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: OS
|
||||
protocol: NA
|
||||
state: GOOD
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
89
tests/config/one_node_states_on_off_main_config.yaml
Normal file
89
tests/config/one_node_states_on_off_main_config.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agentIdentifier: GENERIC
|
||||
# Number of episodes to run per session
|
||||
numEpisodes: 1
|
||||
# Time delay between steps (for generic agents)
|
||||
timeDelay: 1
|
||||
# Filename of the scenario / laydown
|
||||
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
sessionType: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
loadAgent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observationSpaceHighValue: 1000000000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
allOk: 0
|
||||
# Node Operating State
|
||||
offShouldBeOn: -10
|
||||
offShouldBeResetting: -5
|
||||
onShouldBeOff: -2
|
||||
onShouldBeResetting: -5
|
||||
resettingShouldBeOn: -5
|
||||
resettingShouldBeOff: -2
|
||||
resetting: -3
|
||||
# Node O/S or Service State
|
||||
goodShouldBePatching: 2
|
||||
goodShouldBeCompromised: 5
|
||||
goodShouldBeOverwhelmed: 5
|
||||
patchingShouldBeGood: -5
|
||||
patchingShouldBeCompromised: 2
|
||||
patchingShouldBeOverwhelmed: 2
|
||||
patching: -3
|
||||
compromisedShouldBeGood: -20
|
||||
compromisedShouldBePatching: -20
|
||||
compromisedShouldBeOverwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmedShouldBeGood: -20
|
||||
overwhelmedShouldBePatching: -20
|
||||
overwhelmedShouldBeCompromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
goodShouldBeRepairing: 2
|
||||
goodShouldBeRestoring: 2
|
||||
goodShouldBeCorrupt: 5
|
||||
goodShouldBeDestroyed: 10
|
||||
repairingShouldBeGood: -5
|
||||
repairingShouldBeRestoring: 2
|
||||
repairingShouldBeCorrupt: 2
|
||||
repairingShouldBeDestroyed: 0
|
||||
repairing: -3
|
||||
restoringShouldBeGood: -10
|
||||
restoringShouldBeRepairing: -2
|
||||
restoringShouldBeCorrupt: 1
|
||||
restoringShouldBeDestroyed: 2
|
||||
restoring: -6
|
||||
corruptShouldBeGood: -10
|
||||
corruptShouldBeRepairing: -10
|
||||
corruptShouldBeRestoring: -10
|
||||
corruptShouldBeDestroyed: 2
|
||||
corrupt: -10
|
||||
destroyedShouldBeGood: -20
|
||||
destroyedShouldBeRepairing: -20
|
||||
destroyedShouldBeRestoring: -20
|
||||
destroyedShouldBeCorrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
redIerRunning: -5
|
||||
greenIerBlocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
osPatchingDuration: 5 # The time taken to patch the OS
|
||||
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||
servicePatchingDuration: 5 # The time taken to patch a service
|
||||
fileSystemRepairingLimit: 5 # The time taken to repair the file system
|
||||
fileSystemRestoringLimit: 5 # The time taken to restore the file system
|
||||
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||
@@ -1 +1,199 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.common.config_values_main import config_values_main
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
|
||||
def _get_primaite_env_from_config(
    main_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]
):
    """Build a Primaite environment from a main config and a lay down config.

    The main config YAML is parsed and its values copied onto a fresh
    ``config_values_main`` instance, which is then used to construct the
    ``Primaite`` environment. When the configured agent identifier is
    ``"GENERIC"`` the environment is also driven through ``run_generic``
    before being returned.

    Args:
        main_config_path: Path of the main config YAML file.
        lay_down_config_path: Path of the lay down config file; stored
            unchanged as ``config_values.config_filename_use_case``.

    Returns:
        The created ``Primaite`` instance.

    Raises:
        KeyError: If the main config lacks one of the expected keys.
    """
    # Main config keys whose values are stored as ints. Each one is written to
    # the attribute whose name is the snake_case form of the camelCase key.
    int_keys = (
        "numEpisodes",
        "timeDelay",
        # Environment
        "observationSpaceHighValue",
        # Reward values - generic
        "allOk",
        # Node Operating State
        "offShouldBeOn",
        "offShouldBeResetting",
        "onShouldBeOff",
        "onShouldBeResetting",
        "resettingShouldBeOn",
        "resettingShouldBeOff",
        "resetting",
        # Node O/S or Service State
        "goodShouldBePatching",
        "goodShouldBeCompromised",
        "goodShouldBeOverwhelmed",
        "patchingShouldBeGood",
        "patchingShouldBeCompromised",
        "patchingShouldBeOverwhelmed",
        "patching",
        "compromisedShouldBeGood",
        "compromisedShouldBePatching",
        "compromisedShouldBeOverwhelmed",
        "compromised",
        "overwhelmedShouldBeGood",
        "overwhelmedShouldBePatching",
        "overwhelmedShouldBeCompromised",
        "overwhelmed",
        # Node File System State
        "goodShouldBeRepairing",
        "goodShouldBeRestoring",
        "goodShouldBeCorrupt",
        "goodShouldBeDestroyed",
        "repairingShouldBeGood",
        "repairingShouldBeRestoring",
        "repairingShouldBeCorrupt",
        "repairingShouldBeDestroyed",
        "repairing",
        "restoringShouldBeGood",
        "restoringShouldBeRepairing",
        "restoringShouldBeCorrupt",
        "restoringShouldBeDestroyed",
        "restoring",
        "corruptShouldBeGood",
        "corruptShouldBeRepairing",
        "corruptShouldBeRestoring",
        "corruptShouldBeDestroyed",
        "corrupt",
        "destroyedShouldBeGood",
        "destroyedShouldBeRepairing",
        "destroyedShouldBeRestoring",
        "destroyedShouldBeCorrupt",
        "destroyed",
        "scanning",
        # IER status
        "redIerRunning",
        "greenIerBlocked",
        # Patching / Reset durations
        "osPatchingDuration",
        "nodeResetDuration",
        "servicePatchingDuration",
        "fileSystemRepairingLimit",
        "fileSystemRestoringLimit",
        "fileSystemScanningLimit",
    )

    def to_attribute_name(key: str) -> str:
        """Convert a camelCase config key to its snake_case attribute name."""
        return "".join(
            f"_{char.lower()}" if char.isupper() else char for char in key
        )

    # Use a context manager so the file handle is closed once parsed.
    with open(main_config_path, "r") as config_file_main:
        config_data = yaml.safe_load(config_file_main)

    # Create a config class and load in the config data
    config_values = config_values_main()
    config_values.agent_identifier = config_data["agentIdentifier"]
    # The lay down path comes from the caller, not from "configFilename".
    config_values.config_filename_use_case = lay_down_config_path
    config_values.session_type = config_data["sessionType"]
    config_values.load_agent = bool(config_data["loadAgent"])
    config_values.agent_load_file = config_data["agentLoadFile"]
    for key in int_keys:
        setattr(config_values, to_attribute_name(key), int(config_data[key]))

    env = Primaite(config_values, [])
    # The step count is only known once the environment has been built.
    config_values.num_steps = env.episode_steps

    if env.config_values.agent_identifier == "GENERIC":
        run_generic(env, config_values)

    return env
|
||||
|
||||
|
||||
def run_generic(env, config_values):
    """Run the environment against a generic agent using random actions.

    For each of ``config_values.num_episodes`` episodes, up to
    ``config_values.num_steps`` steps are taken. Every step samples an action
    from ``env.action_space`` and passes it to ``env.step``. An episode ends
    early as soon as ``env.step`` reports ``done``.

    Args:
        env: Environment providing ``action_space.sample()`` and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        config_values: Object providing ``num_episodes``, ``num_steps`` and
            ``time_delay``; the delay is divided by 1000 before being passed
            to ``time.sleep``.
    """
    # NOTE: the environment is neither reset between episodes nor closed at
    # the end; those calls were disabled in the original implementation.
    for _ in range(config_values.num_episodes):
        for _ in range(config_values.num_steps):
            # TEMP - random action for now, in place of a real agent policy
            action = env.action_space.sample()

            # Run the simulation step on the live environment; only the
            # ``done`` flag is needed here.
            _, _, done, _ = env.step(action)
            if done:
                break

            # Introduce a delay between steps
            time.sleep(config_values.time_delay / 1000)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
from tests.conftest import _get_primaite_env_from_config
|
||||
|
||||
|
||||
def test_rewards_are_being_penalised_at_each_step_function():
    """
    Test that node state rewards are applied at each step.

    Runs the one-node on/off scenario and checks the average reward of the
    episode against the value expected when file system, hardware, service
    and operating system state rewards are scored on every step.
    """
    env = _get_primaite_env_from_config(
        main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
        lay_down_config_path=TEST_CONFIG_ROOT
        / "one_node_states_on_off_lay_down_config.yaml",
    )

    # On different steps (of the 13 in total) these are the following rewards
    # for config_6 which are activated:
    #   File System State: goodShouldBeCorrupt = 5 (Step 3)
    #   Hardware State: onShouldBeOff = -2 (Step 5)
    #   Service State: goodShouldBeCompromised = 5 (Step 7)
    #   Operating System State (Software State): goodShouldBeCompromised = 5 (Step 10)
    #
    # Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
    # Step Count: 13
    #
    # Average Reward: 2 (26 / 13)
    print("average reward", env.average_reward)
    assert env.average_reward == 2.0
|
||||
|
||||
Reference in New Issue
Block a user