Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test

This commit is contained in:
Czar Echavez
2023-07-03 16:56:44 +01:00
16 changed files with 527 additions and 79 deletions

View File

@@ -25,6 +25,12 @@ steps:
versionSpec: '$(python.version)'
displayName: 'Use Python $(python.version)'
- script: |
python -m pip install pre-commit
pre-commit install
pre-commit run --all-files
displayName: 'Run pre-commits'
- script: |
python -m pip install --upgrade pip==23.0.1
pip install wheel==0.38.4 --upgrade

View File

@@ -28,6 +28,10 @@ The environment config file consists of the following attributes:
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
* STABLE_BASELINES3_A2C - Use a SB3 A2C agent
* **random_red_agent** [bool]
Determines if the session should be run with a random red agent
* **action_type** [enum]
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
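The attributes documented above can be sketched as a quick config check. This is a minimal sketch assuming the YAML has already been parsed into a dict; the values shown are illustrative, not from a real session.

```python
# Minimal sketch of validating the attributes documented above.
# The parsed dict and its values are illustrative stand-ins.
config = {
    "agent_identifier": "STABLE_BASELINES3_PPO",
    "random_red_agent": False,  # bool: run the session with a random red agent?
    "action_type": "NODE",      # enum: NODE, ACL, or ANY
}

use_random_red = bool(config.get("random_red_agent", False))
action_type = config["action_type"]
assert action_type in ("NODE", "ACL", "ANY")
```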

View File

@@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session:
* Timestamp
* Episode number
* Step number
* Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y
* Initial observation space (what the blue agent observed when it decided its action)
* Resulting observation space (after the red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSN_X_Y
* Reward value
* Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
**Diagrams**
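The per-step CSV layout described above amounts to a fixed header plus one AS_X column per action element and one column per observation feature. A hedged sketch of building that header (the action length and observation labels are made-up values):

```python
# Sketch of the transaction CSV header: fixed columns, then AS_<x> per action
# element, then one column per observation feature. Values are illustrative.
action_length = 3
obs_labels = ["node_1_hardware_status", "node_1_os_status"]

header = ["Timestamp", "Episode", "Step", "Reward"]
header += ["AS_" + str(x) for x in range(action_length)]
header += obs_labels
```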

View File

@@ -6,11 +6,23 @@
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# RED AGENT IDENTIFIER
# RANDOM or NONE
random_red_agent: False
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
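A minimal sketch of how a consumer might read the `observation_space` section above. The set of known component names is a stand-in for the real component class registry:

```python
# Hypothetical consumption of the observation_space config section above.
obs_space_config = {
    "flatten": False,  # commented out in the YAML above, so treated as False
    "components": [{"name": "NODE_LINK_TABLE"}],
}
known_components = {"NODE_LINK_TABLE", "NODE_STATUSES", "LINK_TRAFFIC_LEVELS"}

flatten = bool(obs_space_config.get("flatten", False))
names = [c["name"] for c in obs_space_config["components"]]
assert all(name in known_components for name in names)
```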

View File

@@ -0,0 +1,99 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# RED AGENT IDENTIFIER
# RANDOM or NONE
random_red_agent: True
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# Reward values
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
# IER status
red_ier_running: -5
green_ier_blocked: -10
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time taken to repair the file system
file_system_restoring_limit: 5 # The time taken to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system
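A step reward under this config is the sum of whichever penalties apply on that step. A small worked example (penalty values copied from the file above; the set of active conditions is illustrative):

```python
# Illustrative step-reward arithmetic using values from this config.
penalties = {
    "all_ok": 0,
    "off_should_be_on": -10,
    "corrupt_should_be_good": -10,
    "compromised_should_be_good": -20,
}
active = ["off_should_be_on", "compromised_should_be_good"]  # hypothetical step
step_reward = sum(penalties[name] for name in active)
```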

View File

@@ -21,6 +21,9 @@ class TrainingConfig:
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
random_red_agent: bool = False
"Creates Random Red Agent Attacks"
action_type: ActionType = ActionType.ANY
"The ActionType to use."

View File

@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
self.structure: List[str]
return NotImplemented
@abstractmethod
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
@abstractmethod
def generate_structure(self) -> List[str]:
"""Return a list of labels for the components of the flattened observation space."""
return NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_index += 1
item_index += 1
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
nodes = self.env.nodes.values()
links = self.env.links.values()
structure = []
for i, node in enumerate(nodes):
node_id = node.node_id
node_labels = [
f"node_{node_id}_id",
f"node_{node_id}_hardware_status",
f"node_{node_id}_os_status",
f"node_{node_id}_fs_status",
]
for j, serv in enumerate(self.env.services_list):
node_labels.append(f"node_{node_id}_service_{serv}_status")
structure.extend(node_labels)
for i, link in enumerate(links):
link_id = link.id
link_labels = [
f"link_{link_id}_id",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
f"link_{link_id}_n/a",
]
for j, serv in enumerate(self.env.services_list):
link_labels.append(f"link_{link_id}_service_{serv}_load")
structure.extend(link_labels)
return structure
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
@@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent):
)
self.current_observation[:] = obs
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
services = self.env.services_list
structure = []
for _, node in self.env.nodes.items():
node_id = node.node_id
structure.append(f"node_{node_id}_hardware_state_NONE")
for state in HardwareState:
structure.append(f"node_{node_id}_hardware_state_{state.name}")
structure.append(f"node_{node_id}_software_state_NONE")
for state in SoftwareState:
structure.append(f"node_{node_id}_software_state_{state.name}")
structure.append(f"node_{node_id}_file_system_state_NONE")
for state in FileSystemState:
structure.append(f"node_{node_id}_file_system_state_{state.name}")
for service in services:
structure.append(f"node_{node_id}_service_{service}_state_NONE")
for state in SoftwareState:
structure.append(
f"node_{node_id}_service_{service}_state_{state.name}"
)
return structure
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
@@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
self.structure = self.generate_structure()
def update(self):
"""Update the observation based on current environment state.
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation[:] = obs
def generate_structure(self):
"""Return a list of labels for the components of the flattened observation space."""
structure = []
for _, link in self.env.links.items():
link_id = link.id
if self._combine_service_traffic:
protocols = ["overall"]
else:
protocols = [protocol.name for protocol in link.protocol_list]
for p in protocols:
for i in range(self._quantisation_levels):
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
return structure
class ObservationsHandler:
"""Component-based observation space handler.
@@ -311,8 +395,17 @@ class ObservationsHandler:
def __init__(self):
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
# the internal observation space (unflattened version of space if flatten=True)
self._space: spaces.Space
# flattened version of the observation space
self._flat_space: spaces.Space
self._observation: Union[Tuple[np.ndarray], np.ndarray]
# used for transactions and when flatten=true
self._flat_observation: np.ndarray
self.flatten: bool = False
def update_obs(self):
"""Fetch fresh information about the environment."""
@@ -321,12 +414,11 @@
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
self._observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
self._observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
self._flat_observation = spaces.flatten(self._space, self._observation)
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
@@ -353,12 +445,31 @@ class ObservationsHandler:
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
# if there are multiple components, build a composite tuple space
if len(component_spaces) == 1:
self.space = component_spaces[0]
self._space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
self._space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
if len(component_spaces) > 0:
self._flat_space = spaces.flatten_space(self._space)
else:
self._flat_space = spaces.Box(0, 1, (0,))
@property
def space(self):
"""Observation space, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_space
else:
return self._space
@property
def current_observation(self):
"""Current observation, return the flattened version if flatten is True."""
if self.flatten:
return self._flat_observation
else:
return self._observation
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -388,6 +499,9 @@ class ObservationsHandler:
# Instantiate the handler
handler = cls()
if obs_space_config.get("flatten"):
handler.flatten = True
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
@@ -401,3 +515,17 @@ class ObservationsHandler:
handler.update_obs()
return handler
def describe_structure(self):
"""Create a list of names for the features of the obs space.
The order of labels follows the flattened version of the space.
"""
# As it turns out, it's not possible to take the gym flattening function and apply it to our labels, so we have
# to fake it. Each component has to hard-code the expected label order after flattening.
labels = []
for obs_comp in self.registered_obs_components:
labels.extend(obs_comp.structure)
return labels
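The labelling contract described in the comment above can be sketched in isolation: each registered component hard-codes its flattened label order in a `structure` attribute and the handler simply concatenates them. `FakeComponent` here is a hypothetical stand-in for the real observation components.

```python
# Sketch of the describe_structure() contract: concatenate each component's
# hard-coded label order. FakeComponent is an illustrative stand-in.
class FakeComponent:
    def __init__(self, labels):
        self.structure = labels

registered = [
    FakeComponent(["node_1_hardware_status", "node_1_os_status"]),
    FakeComponent(["link_9_overall_traffic_level_0", "link_9_overall_traffic_level_1"]),
]
labels = []
for comp in registered:
    labels.extend(comp.structure)
```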

View File

@@ -3,8 +3,10 @@
import copy
import csv
import logging
import uuid as uuid
from datetime import datetime
from pathlib import Path
from random import choice, randint, sample, uniform
from typing import Dict, Tuple, Union
import networkx as nx
@@ -197,7 +199,6 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
# now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -281,6 +282,10 @@ class Primaite(Env):
# Does this for both live and reference nodes
self.reset_environment()
# Create a random red agent to use for this episode
if self.training_config.random_red_agent:
self._create_random_red_agent()
# Reset counters and totals
self.total_reward = 0
self.step_count = 0
@@ -325,7 +330,8 @@ class Primaite(Env):
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
)
# Load the initial observation space into the transaction
transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
transaction.set_obs_space(self.obs_handler._flat_observation)
# Load the action space into the transaction
transaction.set_action_space(copy.deepcopy(action))
@@ -406,8 +412,6 @@ class Primaite(Env):
# 7. Update env_obs
self.update_environent_obs()
# Load the new observation space into the transaction
transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
# 8. Add the transaction to the list of transactions
self.transaction_list.append(copy.deepcopy(transaction))
@@ -1240,3 +1244,136 @@ class Primaite(Env):
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict
def _create_random_red_agent(self):
"""Decide on random red agent for the episode to be called in env.reset()."""
# Reset the current red iers and red node pol
self.red_iers = {}
self.red_node_pol = {}
# Decide how many nodes become compromised
node_list = list(self.nodes.values())
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
max_num_nodes_compromised = len(
computers
) # only computers can become compromised
# random select between 1 and max_num_nodes_compromised
num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
# Decide which of the nodes to compromise
nodes_to_be_compromised = sample(computers, num_nodes_to_compromise)
# choose a random compromise node to be source of attacks
source_node = choice(nodes_to_be_compromised)
# For each of the nodes to be compromised decide which step they become compromised
max_step_compromised = (
self.episode_steps // 2
) # always compromise in first half of episode
# Bandwidth for all links
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
if len(bandwidths) < 1:
msg = "Random red agent cannot be used on a network without any links"
_LOGGER.error(msg)
raise Exception(msg)
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
for n, node in enumerate(nodes_to_be_compromised):
# 1: Use Node PoL to set node to compromised
_id = str(uuid.uuid4())
_start_step = randint(2, max_step_compromised + 1) # step compromised
pol_service_name = choice(list(node.services.keys()))
source_node_service = choice(list(source_node.services.values()))
red_pol = NodeStateInstructionRed(
_id=_id,
_start_step=_start_step,
_end_step=_start_step, # only run for 1 step
_target_node_id=node.node_id,
_pol_initiator="DIRECT",
_pol_type=NodePOLType["SERVICE"],
pol_protocol=pol_service_name,
_pol_state=SoftwareState.COMPROMISED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state,
)
self.red_node_pol[_id] = red_pol
# 2: Launch the attack from compromised node - set the IER
ier_id = str(uuid.uuid4())
# Launch the attack after node is compromised, and not right at the end of the episode
ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
ier_end_step = self.episode_steps
# Randomise the load, as a percentage of a random link bandwidth
ier_load = uniform(0.4, 0.8) * choice(bandwidths)
ier_protocol = pol_service_name # Same protocol as compromised node
ier_service = node.services[pol_service_name]
ier_port = ier_service.port
ier_mission_criticality = (
0 # Red IER will never be important to green agent success
)
# We choose a node to attack based on the first that applies:
# a. Green IERs, select dest node of the red ier based on dest node of green IER
# b. Attack a random server that doesn't have a DENY acl rule in default config
# c. Attack a random server
possible_ier_destinations = [
ier.get_dest_node_id()
for ier in list(self.green_iers.values())
if ier.get_source_node_id() == node.node_id
]
if len(possible_ier_destinations) < 1:
for server in servers:
if not self.acl.is_blocked(
node.ip_address,
server.ip_address,
ier_service,
ier_port,
):
possible_ier_destinations.append(server.node_id)
if len(possible_ier_destinations) < 1:
# If still none found choose from all servers
possible_ier_destinations = [server.node_id for server in servers]
ier_dest = choice(possible_ier_destinations)
self.red_iers[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
node.node_id,
ier_dest,
ier_mission_criticality,
)
overwhelm_pol = red_pol
overwhelm_pol.id = str(uuid.uuid4())
overwhelm_pol.end_step = self.episode_steps
# 3: Make sure the targeted node can be set to overwhelmed - with node pol
# TODO: remove duplicate red pol for same targeted service - must take into account start step
o_pol_id = str(uuid.uuid4())
o_red_pol = NodeStateInstructionRed(
_id=o_pol_id,
_start_step=ier_start_step,
_end_step=self.episode_steps,
_target_node_id=ier_dest,
_pol_initiator="DIRECT",
_pol_type=NodePOLType["SERVICE"],
pol_protocol=ier_protocol,
_pol_state=SoftwareState.OVERWHELMED,
_pol_source_node_id=source_node.node_id,
_pol_source_node_service=source_node_service.name,
_pol_source_node_service_state=source_node_service.software_state,
)
self.red_node_pol[o_pol_id] = o_red_pol
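The node-selection logic at the top of `_create_random_red_agent` can be condensed into a short sketch: pick how many computers to compromise, which ones, and a source node among them. The node names below are illustrative stand-ins for real `Node` objects.

```python
# Condensed sketch of the selection logic in _create_random_red_agent.
from random import choice, randint, sample, seed

seed(0)  # pinned only to make this example repeatable
computers = ["pc_1", "pc_2", "pc_3"]
num_to_compromise = randint(1, len(computers))      # between 1 and all computers
compromised = sample(computers, num_to_compromise)  # which nodes to compromise
source_node = choice(compromised)                   # source of the attacks
```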

View File

@@ -78,8 +78,8 @@ def calculate_reward_function(
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
reference_blocked = reference_ier.get_is_running()
live_blocked = ier_value.get_is_running()
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
)
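The fix in this hunk flips the "blocked" check: an IER is blocked when it is NOT running. A tiny model of the corrected semantics, where `FakeIER` is a stand-in for the real IER class:

```python
# An IER is "blocked" when it is not running (the corrected check above).
class FakeIER:
    def __init__(self, running):
        self._running = running

    def get_is_running(self):
        return self._running

reference_ier = FakeIER(running=True)   # reference network: green IER runs
live_ier = FakeIER(running=False)       # live network: the IER was blocked
reference_blocked = not reference_ier.get_is_running()
live_blocked = not live_ier.get_is_running()
```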

View File

@@ -354,6 +354,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
transaction_list=transaction_list,
session_path=session_dir,
timestamp_str=timestamp_str,
obs_space_description=env.obs_handler.describe_structure(),
)
print("Updating Session Metadata file...")

View File

@@ -1,8 +1,11 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""Defines node behaviour for Green PoL."""
from dataclasses import dataclass
from primaite.common.enums import NodePOLType
@dataclass()
class NodeStateInstructionRed(object):
"""The Node State Instruction class."""

View File

@@ -190,14 +190,14 @@ class ServiceNode(ActiveNode):
service_value.reduce_patching_count()
def update_resetting_status(self):
"""Updates the resetting counter for any service that are resetting."""
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
"""Updates the booting counter for any service that are booting up."""
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():

View File

@@ -20,23 +20,14 @@ class Transaction(object):
self.episode_number = _episode_number
self.step_number = _step_number
def set_obs_space_pre(self, _obs_space_pre):
def set_obs_space(self, _obs_space):
"""
Sets the observation space (pre).
Args:
_obs_space_pre: The observation space before any actions are taken
"""
self.obs_space_pre = _obs_space_pre
self.obs_space = _obs_space
def set_obs_space_post(self, _obs_space_post):
"""
Sets the observation space (post).
Args:
_obs_space_post: The observation space after any actions are taken
"""
self.obs_space_post = _obs_space_post
def set_reward(self, _reward):
"""

View File

@@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space):
return [str(_action_space)]
def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
"""
Turns observation space into a string array so it can be saved to csv.
Args:
_obs_space: The observation space
_obs_assets: The number of assets (i.e. nodes or links) in the observation space
_obs_features: The number of features associated with the asset
"""
return_array = []
for x in range(_obs_assets):
for y in range(_obs_features):
return_array.append(str(_obs_space[x][y]))
return return_array
def write_transaction_to_file(
transaction_list,
session_path: Path,
timestamp_str: str,
obs_space_description: list,
):
"""
Writes transaction logs to file to support training evaluation.
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
# This will be tied into the PrimAITE Use Case so that they make sense
template_transation = transaction_list[0]
action_length = template_transation.action_space.size
obs_shape = template_transation.obs_space_post.shape
obs_assets = template_transation.obs_space_post.shape[0]
if len(obs_shape) == 1:
# bit of a workaround but I think the way transactions are written will change soon
obs_features = 1
else:
obs_features = template_transation.obs_space_post.shape[1]
# obs_shape = template_transation.obs_space_post.shape
# obs_assets = template_transation.obs_space_post.shape[0]
# if len(obs_shape) == 1:
# obs_features = 1
# else:
# obs_features = template_transation.obs_space_post.shape[1]
# Create the action space headers array
action_header = []
@@ -70,16 +58,12 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
action_header.append("AS_" + str(x))
# Create the observation space headers array
obs_header_initial = []
obs_header_new = []
for x in range(obs_assets):
for y in range(obs_features):
obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
obs_header_new.append("OSN_" + str(x) + "_" + str(y))
# obs_header_initial = [f"pre_{o}" for o in obs_space_description]
# obs_header_new = [f"post_{o}" for o in obs_space_description]
# Open up a csv file
header = ["Timestamp", "Episode", "Step", "Reward"]
header = header + action_header + obs_header_initial + obs_header_new
header = header + action_header + obs_space_description
try:
filename = session_path / f"all_transactions_{timestamp_str}.csv"
@@ -98,12 +82,7 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
csv_data = (
csv_data
+ turn_action_space_to_array(transaction.action_space)
+ turn_obs_space_to_array(
transaction.obs_space_pre, obs_assets, obs_features
)
+ turn_obs_space_to_array(
transaction.obs_space_post, obs_assets, obs_features
)
+ transaction.obs_space.tolist()
)
csv_writer.writerow(csv_data)
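After this change, each row is the fixed fields, the action as strings, then the flattened observation appended via `.tolist()`. A small sketch of that row layout with Python's `csv` module; all values are illustrative:

```python
# Sketch of the transaction row layout after this change.
import csv
import io

obs = [0, 1, 0]  # stands in for transaction.obs_space.tolist()
row = ["2023-07-03", 1, 5, -8.0, "0"] + obs
buf = io.StringIO()
csv.writer(buf).writerow(row)
line = buf.getvalue().strip()
```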

View File

@@ -0,0 +1,77 @@
from datetime import datetime
from primaite.config.lay_down_config import data_manipulation_config_path
from primaite.environment.primaite_env import Primaite
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_temp_session_path
def run_generic(env, config_values):
"""Run against a generic agent."""
# Reset the environment at the start of the episode
env.reset()
for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Reset the environment at the end of the episode
env.reset()
env.close()
def test_random_red_agent_behaviour():
"""
Test that the random red agent generates different attack instructions between runs.
Two sessions are run and the red node PoL instructions they produce are compared.
"""
list_of_node_instructions = []
# RUN TWICE so we can make sure that red agent is randomised
for i in range(2):
"""Takes a config path and returns the created instance of Primaite."""
session_timestamp: datetime = datetime.now()
session_path = _get_temp_session_path(session_timestamp)
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
env = Primaite(
training_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_main_config.yaml",
lay_down_config_path=data_manipulation_config_path(),
transaction_list=[],
session_path=session_path,
timestamp_str=timestamp_str,
)
# set random_red_agent to True for this run
env.training_config.random_red_agent = True
training_config = env.training_config
training_config.num_steps = env.episode_steps
run_generic(env, training_config)
# add red pol instructions to list
list_of_node_instructions.append(env.red_node_pol)
# compare instructions to make sure that red instructions are truly random
for index, instruction in enumerate(list_of_node_instructions):
for key in list_of_node_instructions[index].keys():
instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
print(f"run {index}")
print(f"{key} start step: {instruction.get_start_step()}")
print(f"{key} end step: {instruction.get_end_step()}")
print(f"{key} target node id: {instruction.get_target_node_id()}")
print("")
assert list_of_node_instructions[0] != list_of_node_instructions[1]

View File

@@ -16,17 +16,26 @@ def test_rewards_are_being_penalised_at_each_step_function():
)
"""
On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps:
File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3)
On different steps, the laydown config has Patterns of Life (PoLs) which change a state of the node's attributes.
Hardware State: onShouldBeOff = -2 (between Steps 4 & 6)
For example, turning the node's file system state to CORRUPT from its original state GOOD.
Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9)
As a result, the following rewards are activated:
Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12)
File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2)
Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5)
Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8)
Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11)
Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
Each PoL lasts for 2 steps, so the agent is penalised twice per PoL.
Step Count: 13
Note: This test run inherits from conftest.py where the PrimAITE environment is run and the blue agent is hard-coded
to do NOTHING on every step.
We use Patterns of Life (PoLs) to change the node states and show that the agent is penalised on all steps
where the live network node differs from the reference network node.
Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120
Step Count: 15
For the 4 steps where this occurs the average reward is:
Average Reward: 2 (26 / 13)
Average Reward: -8 (-120 / 15)
"""
print("average reward", env.average_reward)
assert env.average_reward == -8.0
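The docstring's arithmetic can be checked explicitly: four PoLs, each penalised on two steps, averaged over the 15-step episode.

```python
# The reward arithmetic from the test docstring above, checked explicitly.
penalties = [-10, -10, -10, -10, -20, -20, -20, -20]
total_reward = sum(penalties)
step_count = 15
average_reward = total_reward / step_count
```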