Merge branch 'dev' into feature/1386-enable-a-repeatable-or-deterministic-baseline-test
@@ -25,6 +25,12 @@ steps:
     versionSpec: '$(python.version)'
   displayName: 'Use Python $(python.version)'
 
+- script: |
+    python -m pip install pre-commit
+    pre-commit install
+    pre-commit run --all-files
+  displayName: 'Run pre-commits'
+
 - script: |
     python -m pip install --upgrade pip==23.0.1
     pip install wheel==0.38.4 --upgrade
@@ -28,6 +28,10 @@ The environment config file consists of the following attributes:
 * STABLE_BASELINES3_PPO - Use a SB3 PPO agent
 * STABLE_BASELINES3_A2C - use a SB3 A2C agent
 
+* **random_red_agent** [bool]
+
+  Determines if the session should be run with a random red agent
+
 * **action_type** [enum]
 
   Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
@@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session:
 * Timestamp
 * Episode number
 * Step number
-* Initial observation space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y
-* Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y
+* Initial observation space (what the blue agent observed when it decided its action)
 * Reward value
-* Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
+* Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X
 
 **Diagrams**
 
@@ -6,11 +6,23 @@
 # "STABLE_BASELINES3_A2C"
 # "GENERIC"
 agent_identifier: STABLE_BASELINES3_A2C
+
+# RED AGENT IDENTIFIER
+# RANDOM or NONE
+random_red_agent: False
+
 # Sets How the Action Space is defined:
 # "NODE"
 # "ACL"
 # "ANY" node and acl actions
 action_type: NODE
+# observation space
+observation_space:
+  # flatten: true
+  components:
+    - name: NODE_LINK_TABLE
+    # - name: NODE_STATUSES
+    # - name: LINK_TRAFFIC_LEVELS
 # Number of episodes to run per session
 num_episodes: 10
 # Number of time_steps per episode
@@ -0,0 +1,99 @@
+# Main Config File
+
+# Generic config values
+# Choose one of these (dependent on Agent being trained)
+# "STABLE_BASELINES3_PPO"
+# "STABLE_BASELINES3_A2C"
+# "GENERIC"
+agent_identifier: STABLE_BASELINES3_A2C
+
+# RED AGENT IDENTIFIER
+# RANDOM or NONE
+random_red_agent: True
+
+# Sets How the Action Space is defined:
+# "NODE"
+# "ACL"
+# "ANY" node and acl actions
+action_type: NODE
+# Number of episodes to run per session
+num_episodes: 10
+# Number of time_steps per episode
+num_steps: 256
+# Time delay between steps (for generic agents)
+time_delay: 10
+# Type of session to be run (TRAINING or EVALUATION)
+session_type: TRAINING
+# Determine whether to load an agent from file
+load_agent: False
+# File path and file name of agent if you're loading one in
+agent_load_file: C:\[Path]\[agent_saved_filename.zip]
+
+# Environment config values
+# The high value for the observation space
+observation_space_high_value: 1000000000
+
+# Reward values
+# Generic
+all_ok: 0
+# Node Hardware State
+off_should_be_on: -10
+off_should_be_resetting: -5
+on_should_be_off: -2
+on_should_be_resetting: -5
+resetting_should_be_on: -5
+resetting_should_be_off: -2
+resetting: -3
+# Node Software or Service State
+good_should_be_patching: 2
+good_should_be_compromised: 5
+good_should_be_overwhelmed: 5
+patching_should_be_good: -5
+patching_should_be_compromised: 2
+patching_should_be_overwhelmed: 2
+patching: -3
+compromised_should_be_good: -20
+compromised_should_be_patching: -20
+compromised_should_be_overwhelmed: -20
+compromised: -20
+overwhelmed_should_be_good: -20
+overwhelmed_should_be_patching: -20
+overwhelmed_should_be_compromised: -20
+overwhelmed: -20
+# Node File System State
+good_should_be_repairing: 2
+good_should_be_restoring: 2
+good_should_be_corrupt: 5
+good_should_be_destroyed: 10
+repairing_should_be_good: -5
+repairing_should_be_restoring: 2
+repairing_should_be_corrupt: 2
+repairing_should_be_destroyed: 0
+repairing: -3
+restoring_should_be_good: -10
+restoring_should_be_repairing: -2
+restoring_should_be_corrupt: 1
+restoring_should_be_destroyed: 2
+restoring: -6
+corrupt_should_be_good: -10
+corrupt_should_be_repairing: -10
+corrupt_should_be_restoring: -10
+corrupt_should_be_destroyed: 2
+corrupt: -10
+destroyed_should_be_good: -20
+destroyed_should_be_repairing: -20
+destroyed_should_be_restoring: -20
+destroyed_should_be_corrupt: -20
+destroyed: -20
+scanning: -2
+# IER status
+red_ier_running: -5
+green_ier_blocked: -10
+
+# Patching / Reset durations
+os_patching_duration: 5 # The time taken to patch the OS
+node_reset_duration: 5 # The time taken to reset a node (hardware)
+service_patching_duration: 5 # The time taken to patch a service
+file_system_repairing_limit: 5 # The time taken to repair the file system
+file_system_restoring_limit: 5 # The time taken to restore the file system
+file_system_scanning_limit: 5 # The time taken to scan the file system
@@ -21,6 +21,9 @@ class TrainingConfig:
     agent_identifier: str = "STABLE_BASELINES3_A2C"
     "The Red Agent algo/class to be used."
 
+    random_red_agent: bool = False
+    "Creates Random Red Agent Attacks"
+
     action_type: ActionType = ActionType.ANY
     "The ActionType to use."
 
@@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC):
         self.env: "Primaite" = env
         self.space: spaces.Space
         self.current_observation: np.ndarray  # type might be too restrictive?
+        self.structure: List[str]
         return NotImplemented
 
     @abstractmethod
@@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC):
         """Update the observation based on the current state of the environment."""
         self.current_observation = NotImplemented
 
+    @abstractmethod
+    def generate_structure(self) -> List[str]:
+        """Return a list of labels for the components of the flattened observation space."""
+        return NotImplemented
+
 
 class NodeLinkTable(AbstractObservationComponent):
     """Table with nodes and links as rows and hardware/software status as cols.
@@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent):
         # 3. Initialise Observation with zeroes
         self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
 
+        self.structure = self.generate_structure()
+
     def update(self):
         """Update the observation based on current environment state.
 
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent):
                 protocol_index += 1
             item_index += 1
 
+    def generate_structure(self):
+        """Return a list of labels for the components of the flattened observation space."""
+        nodes = self.env.nodes.values()
+        links = self.env.links.values()
+
+        structure = []
+
+        for i, node in enumerate(nodes):
+            node_id = node.node_id
+            node_labels = [
+                f"node_{node_id}_id",
+                f"node_{node_id}_hardware_status",
+                f"node_{node_id}_os_status",
+                f"node_{node_id}_fs_status",
+            ]
+            for j, serv in enumerate(self.env.services_list):
+                node_labels.append(f"node_{node_id}_service_{serv}_status")
+
+            structure.extend(node_labels)
+
+        for i, link in enumerate(links):
+            link_id = link.id
+            link_labels = [
+                f"link_{link_id}_id",
+                f"link_{link_id}_n/a",
+                f"link_{link_id}_n/a",
+                f"link_{link_id}_n/a",
+            ]
+            for j, serv in enumerate(self.env.services_list):
+                link_labels.append(f"link_{link_id}_service_{serv}_load")
+
+            structure.extend(link_labels)
+        return structure
+
 
 class NodeStatuses(AbstractObservationComponent):
     """Flat list of nodes' hardware, OS, file system, and service states.
@@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent):
 
         # 3. Initialise observation with zeroes
         self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
+        self.structure = self.generate_structure()
 
     def update(self):
         """Update the observation based on current environment state.
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent):
         )
         self.current_observation[:] = obs
 
+    def generate_structure(self):
+        """Return a list of labels for the components of the flattened observation space."""
+        services = self.env.services_list
+
+        structure = []
+        for _, node in self.env.nodes.items():
+            node_id = node.node_id
+            structure.append(f"node_{node_id}_hardware_state_NONE")
+            for state in HardwareState:
+                structure.append(f"node_{node_id}_hardware_state_{state.name}")
+            structure.append(f"node_{node_id}_software_state_NONE")
+            for state in SoftwareState:
+                structure.append(f"node_{node_id}_software_state_{state.name}")
+            structure.append(f"node_{node_id}_file_system_state_NONE")
+            for state in FileSystemState:
+                structure.append(f"node_{node_id}_file_system_state_{state.name}")
+            for service in services:
+                structure.append(f"node_{node_id}_service_{service}_state_NONE")
+                for state in SoftwareState:
+                    structure.append(
+                        f"node_{node_id}_service_{service}_state_{state.name}"
+                    )
+        return structure
+
 
 class LinkTrafficLevels(AbstractObservationComponent):
     """Flat list of traffic levels encoded into banded categories.
@@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
         # 3. Initialise observation with zeroes
         self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
 
+        self.structure = self.generate_structure()
+
     def update(self):
         """Update the observation based on current environment state.
 
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent):
 
         self.current_observation[:] = obs
 
+    def generate_structure(self):
+        """Return a list of labels for the components of the flattened observation space."""
+        structure = []
+        for _, link in self.env.links.items():
+            link_id = link.id
+            if self._combine_service_traffic:
+                protocols = ["overall"]
+            else:
+                protocols = [protocol.name for protocol in link.protocol_list]
+
+            for p in protocols:
+                for i in range(self._quantisation_levels):
+                    structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
+        return structure
+
 
 class ObservationsHandler:
     """Component-based observation space handler.
@@ -311,8 +395,17 @@ class ObservationsHandler:
 
     def __init__(self):
         self.registered_obs_components: List[AbstractObservationComponent] = []
-        self.space: spaces.Space
-        self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
+        # the internal observation space (unflattened version of space if flatten=True)
+        self._space: spaces.Space
+        # flattened version of the observation space
+        self._flat_space: spaces.Space
+
+        self._observation: Union[Tuple[np.ndarray], np.ndarray]
+        # used for transactions and when flatten=true
+        self._flat_observation: np.ndarray
+
+        self.flatten: bool = False
 
     def update_obs(self):
         """Fetch fresh information about the environment."""
@@ -321,12 +414,11 @@ class ObservationsHandler:
             obs.update()
             current_obs.append(obs.current_observation)
 
-        # If there is only one component, don't use a tuple, just pass through that component's obs.
         if len(current_obs) == 1:
-            self.current_observation = current_obs[0]
+            self._observation = current_obs[0]
         else:
-            self.current_observation = tuple(current_obs)
-        # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
+            self._observation = tuple(current_obs)
+        self._flat_observation = spaces.flatten(self._space, self._observation)
 
     def register(self, obs_component: AbstractObservationComponent):
         """Add a component for this handler to track.
@@ -353,12 +445,31 @@ class ObservationsHandler:
         for obs_comp in self.registered_obs_components:
             component_spaces.append(obs_comp.space)
 
-        # If there is only one component, don't use a tuple space, just pass through that component's space.
+        # if there are multiple components, build a composite tuple space
         if len(component_spaces) == 1:
-            self.space = component_spaces[0]
+            self._space = component_spaces[0]
         else:
-            self.space = spaces.Tuple(component_spaces)
-        # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
+            self._space = spaces.Tuple(component_spaces)
+        if len(component_spaces) > 0:
+            self._flat_space = spaces.flatten_space(self._space)
+        else:
+            self._flat_space = spaces.Box(0, 1, (0,))
+
+    @property
+    def space(self):
+        """Observation space, return the flattened version if flatten is True."""
+        if self.flatten:
+            return self._flat_space
+        else:
+            return self._space
+
+    @property
+    def current_observation(self):
+        """Current observation, return the flattened version if flatten is True."""
+        if self.flatten:
+            return self._flat_observation
+        else:
+            return self._observation
+
     @classmethod
     def from_config(cls, env: "Primaite", obs_space_config: dict):
@@ -388,6 +499,9 @@ class ObservationsHandler:
         # Instantiate the handler
         handler = cls()
 
+        if obs_space_config.get("flatten"):
+            handler.flatten = True
+
         for component_cfg in obs_space_config["components"]:
             # Figure out which class can instantiate the desired component
             comp_type = component_cfg["name"]
@@ -401,3 +515,17 @@ class ObservationsHandler:
 
         handler.update_obs()
         return handler
+
+    def describe_structure(self):
+        """Create a list of names for the features of the obs space.
+
+        The order of labels follows the flattened version of the space.
+        """
+        # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
+        # to fake it. each component has to just hard-code the expected label order after flattening...
+
+        labels = []
+        for obs_comp in self.registered_obs_components:
+            labels.extend(obs_comp.structure)
+
+        return labels
@@ -3,8 +3,10 @@
 import copy
 import csv
 import logging
+import uuid as uuid
 from datetime import datetime
 from pathlib import Path
+from random import choice, randint, sample, uniform
 from typing import Dict, Tuple, Union
 
 import networkx as nx
@@ -197,7 +199,6 @@ class Primaite(Env):
         try:
             plt.tight_layout()
             nx.draw_networkx(self.network, with_labels=True)
-            # now = datetime.now()  # current date and time
 
             file_path = session_path / f"network_{timestamp_str}.png"
             plt.savefig(file_path, format="PNG")
@@ -281,6 +282,10 @@ class Primaite(Env):
         # Does this for both live and reference nodes
         self.reset_environment()
 
+        # Create a random red agent to use for this episode
+        if self.training_config.random_red_agent:
+            self._create_random_red_agent()
+
         # Reset counters and totals
         self.total_reward = 0
         self.step_count = 0
@@ -325,7 +330,8 @@ class Primaite(Env):
             datetime.now(), self.agent_identifier, self.episode_count, self.step_count
         )
         # Load the initial observation space into the transaction
-        transaction.set_obs_space_pre(copy.deepcopy(self.env_obs))
+        transaction.set_obs_space(self.obs_handler._flat_observation)
+
         # Load the action space into the transaction
         transaction.set_action_space(copy.deepcopy(action))
 
@@ -406,8 +412,6 @@ class Primaite(Env):
 
         # 7. Update env_obs
         self.update_environent_obs()
-        # Load the new observation space into the transaction
-        transaction.set_obs_space_post(copy.deepcopy(self.env_obs))
 
         # 8. Add the transaction to the list of transactions
         self.transaction_list.append(copy.deepcopy(transaction))
@@ -1240,3 +1244,136 @@ class Primaite(Env):
         # Combine the Node dict and ACL dict
         combined_action_dict = {**acl_action_dict, **new_node_action_dict}
         return combined_action_dict
+
+    def _create_random_red_agent(self):
+        """Decide on random red agent for the episode to be called in env.reset()."""
+        # Reset the current red iers and red node pol
+        self.red_iers = {}
+        self.red_node_pol = {}
+
+        # Decide how many nodes become compromised
+        node_list = list(self.nodes.values())
+        computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
+        max_num_nodes_compromised = len(
+            computers
+        )  # only computers can become compromised
+        # randomly select between 1 and max_num_nodes_compromised
+        num_nodes_to_compromise = randint(1, max_num_nodes_compromised)
+
+        # Decide which of the nodes to compromise
+        nodes_to_be_compromised = sample(computers, num_nodes_to_compromise)
+
+        # choose a random compromised node to be the source of attacks
+        source_node = choice(nodes_to_be_compromised)
+
+        # For each of the nodes to be compromised decide which step they become compromised
+        max_step_compromised = (
+            self.episode_steps // 2
+        )  # always compromise in first half of episode
+
+        # Bandwidth for all links
+        bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
+
+        if len(bandwidths) < 1:
+            msg = "Random red agent cannot be used on a network without any links"
+            _LOGGER.error(msg)
+            raise Exception(msg)
+
+        servers = [node for node in node_list if node.node_type == NodeType.SERVER]
+
+        for n, node in enumerate(nodes_to_be_compromised):
+            # 1: Use Node PoL to set node to compromised
+
+            _id = str(uuid.uuid4())
+            _start_step = randint(2, max_step_compromised + 1)  # step compromised
+            pol_service_name = choice(list(node.services.keys()))
+
+            source_node_service = choice(list(source_node.services.values()))
+
+            red_pol = NodeStateInstructionRed(
+                _id=_id,
+                _start_step=_start_step,
+                _end_step=_start_step,  # only run for 1 step
+                _target_node_id=node.node_id,
+                _pol_initiator="DIRECT",
+                _pol_type=NodePOLType["SERVICE"],
+                pol_protocol=pol_service_name,
+                _pol_state=SoftwareState.COMPROMISED,
+                _pol_source_node_id=source_node.node_id,
+                _pol_source_node_service=source_node_service.name,
+                _pol_source_node_service_state=source_node_service.software_state,
+            )
+
+            self.red_node_pol[_id] = red_pol
+
+            # 2: Launch the attack from compromised node - set the IER
+
+            ier_id = str(uuid.uuid4())
+            # Launch the attack after node is compromised, and not right at the end of the episode
+            ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8))
+            ier_end_step = self.episode_steps
+
+            # Randomise the load, as a percentage of a random link bandwidth
+            ier_load = uniform(0.4, 0.8) * choice(bandwidths)
+            ier_protocol = pol_service_name  # Same protocol as compromised node
+            ier_service = node.services[pol_service_name]
+            ier_port = ier_service.port
+            ier_mission_criticality = (
+                0  # Red IER will never be important to green agent success
+            )
+            # We choose a node to attack based on the first that applies:
+            # a. Green IERs, select dest node of the red ier based on dest node of green IER
+            # b. Attack a random server that doesn't have a DENY acl rule in default config
+            # c. Attack a random server
+            possible_ier_destinations = [
+                ier.get_dest_node_id()
+                for ier in list(self.green_iers.values())
+                if ier.get_source_node_id() == node.node_id
+            ]
+            if len(possible_ier_destinations) < 1:
+                for server in servers:
+                    if not self.acl.is_blocked(
+                        node.ip_address,
+                        server.ip_address,
+                        ier_service,
+                        ier_port,
+                    ):
+                        possible_ier_destinations.append(server.node_id)
+            if len(possible_ier_destinations) < 1:
+                # If still none found choose from all servers
+                possible_ier_destinations = [server.node_id for server in servers]
+            ier_dest = choice(possible_ier_destinations)
+            self.red_iers[ier_id] = IER(
+                ier_id,
+                ier_start_step,
+                ier_end_step,
+                ier_load,
+                ier_protocol,
+                ier_port,
+                node.node_id,
+                ier_dest,
+                ier_mission_criticality,
+            )
+
+            overwhelm_pol = red_pol
+            overwhelm_pol.id = str(uuid.uuid4())
+            overwhelm_pol.end_step = self.episode_steps
+
+            # 3: Make sure the targeted node can be set to overwhelmed - with node pol
+            # TODO: remove duplicate red pol for same targeted service - must take into account start step
+
+            o_pol_id = str(uuid.uuid4())
+            o_red_pol = NodeStateInstructionRed(
+                _id=o_pol_id,
+                _start_step=ier_start_step,
+                _end_step=self.episode_steps,
+                _target_node_id=ier_dest,
+                _pol_initiator="DIRECT",
+                _pol_type=NodePOLType["SERVICE"],
+                pol_protocol=ier_protocol,
+                _pol_state=SoftwareState.OVERWHELMED,
+                _pol_source_node_id=source_node.node_id,
+                _pol_source_node_service=source_node_service.name,
+                _pol_source_node_service_state=source_node_service.software_state,
+            )
+            self.red_node_pol[o_pol_id] = o_red_pol
@@ -78,8 +78,8 @@ def calculate_reward_function(
         start_step = ier_value.get_start_step()
         stop_step = ier_value.get_end_step()
         if step_count >= start_step and step_count <= stop_step:
-            reference_blocked = reference_ier.get_is_running()
-            live_blocked = ier_value.get_is_running()
+            reference_blocked = not reference_ier.get_is_running()
+            live_blocked = not ier_value.get_is_running()
             ier_reward = (
                 config_values.green_ier_blocked * ier_value.get_mission_criticality()
             )
@@ -354,6 +354,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str,
         transaction_list=transaction_list,
         session_path=session_dir,
         timestamp_str=timestamp_str,
+        obs_space_description=env.obs_handler.describe_structure(),
     )
 
     print("Updating Session Metadata file...")
@@ -1,8 +1,11 @@
 # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
 """Defines node behaviour for Green PoL."""
+from dataclasses import dataclass
+
 from primaite.common.enums import NodePOLType
 
 
+@dataclass()
 class NodeStateInstructionRed(object):
     """The Node State Instruction class."""
 
@@ -190,14 +190,14 @@ class ServiceNode(ActiveNode):
|
|||||||
service_value.reduce_patching_count()
|
service_value.reduce_patching_count()
|
||||||
|
|
||||||
def update_resetting_status(self):
|
def update_resetting_status(self):
|
||||||
"""Updates the resetting counter for any service that are resetting."""
|
"""Update resetting counter and set software state if it reached 0."""
|
||||||
super().update_resetting_status()
|
super().update_resetting_status()
|
||||||
if self.resetting_count <= 0:
|
if self.resetting_count <= 0:
|
||||||
for service in self.services.values():
|
for service in self.services.values():
|
||||||
service.software_state = SoftwareState.GOOD
|
service.software_state = SoftwareState.GOOD
|
||||||
|
|
||||||
def update_booting_status(self):
|
def update_booting_status(self):
|
||||||
"""Updates the booting counter for any service that are booting up."""
|
"""Update booting counter and set software to good if it reached 0."""
|
||||||
super().update_booting_status()
|
super().update_booting_status()
|
||||||
if self.booting_count <= 0:
|
if self.booting_count <= 0:
|
||||||
for service in self.services.values():
|
for service in self.services.values():
|
||||||
|
@@ -20,23 +20,14 @@ class Transaction(object):
         self.episode_number = _episode_number
         self.step_number = _step_number

-    def set_obs_space_pre(self, _obs_space_pre):
+    def set_obs_space(self, _obs_space):
         """
         Sets the observation space (pre).

         Args:
             _obs_space_pre: The observation space before any actions are taken
         """
-        self.obs_space_pre = _obs_space_pre
+        self.obs_space = _obs_space

-    def set_obs_space_post(self, _obs_space_post):
-        """
-        Sets the observation space (post).
-
-        Args:
-            _obs_space_post: The observation space after any actions are taken
-        """
-        self.obs_space_post = _obs_space_post
-
     def set_reward(self, _reward):
         """
@@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space):
     return [str(_action_space)]


-def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features):
-    """
-    Turns observation space into a string array so it can be saved to csv.
-
-    Args:
-        _obs_space: The observation space
-        _obs_assets: The number of assets (i.e. nodes or links) in the observation space
-        _obs_features: The number of features associated with the asset
-    """
-    return_array = []
-    for x in range(_obs_assets):
-        for y in range(_obs_features):
-            return_array.append(str(_obs_space[x][y]))
-
-    return return_array
-
-
-def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str):
+def write_transaction_to_file(
+    transaction_list,
+    session_path: Path,
+    timestamp_str: str,
+    obs_space_description: list,
+):
     """
     Writes transaction logs to file to support training evaluation.
@@ -56,13 +44,13 @@ write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
     # This will be tied into the PrimAITE Use Case so that they make sense
     template_transation = transaction_list[0]
     action_length = template_transation.action_space.size
-    obs_shape = template_transation.obs_space_post.shape
+    # obs_shape = template_transation.obs_space_post.shape
-    obs_assets = template_transation.obs_space_post.shape[0]
+    # obs_assets = template_transation.obs_space_post.shape[0]
-    if len(obs_shape) == 1:
+    # if len(obs_shape) == 1:
         # bit of a workaround but I think the way transactions are written will change soon
-        obs_features = 1
+    #     obs_features = 1
-    else:
+    # else:
-        obs_features = template_transation.obs_space_post.shape[1]
+    #     obs_features = template_transation.obs_space_post.shape[1]

     # Create the action space headers array
     action_header = []
@@ -70,16 +58,12 @@ write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
         action_header.append("AS_" + str(x))

     # Create the observation space headers array
-    obs_header_initial = []
+    # obs_header_initial = [f"pre_{o}" for o in obs_space_description]
-    obs_header_new = []
+    # obs_header_new = [f"post_{o}" for o in obs_space_description]
-    for x in range(obs_assets):
-        for y in range(obs_features):
-            obs_header_initial.append("OSI_" + str(x) + "_" + str(y))
-            obs_header_new.append("OSN_" + str(x) + "_" + str(y))

     # Open up a csv file
     header = ["Timestamp", "Episode", "Step", "Reward"]
-    header = header + action_header + obs_header_initial + obs_header_new
+    header = header + action_header + obs_space_description

     try:
         filename = session_path / f"all_transactions_{timestamp_str}.csv"
@@ -98,12 +82,7 @@ write_transaction_to_file(transaction_list, session_path: Path, timestamp_st
             csv_data = (
                 csv_data
                 + turn_action_space_to_array(transaction.action_space)
-                + turn_obs_space_to_array(
-                    transaction.obs_space_pre, obs_assets, obs_features
-                )
-                + turn_obs_space_to_array(
-                    transaction.obs_space_post, obs_assets, obs_features
-                )
+                + transaction.obs_space.tolist()
             )
             csv_writer.writerow(csv_data)
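Taken together, the changes in this file replace the per-asset/per-feature header and row loops with a flat header list (`obs_space_description`) plus `numpy`'s `tolist()` on the stored observation. A sketch of how the new header and row line up, assuming hypothetical column names (PrimAITE's real description strings come from `obs_handler.describe_structure()`):

```python
import csv
import io

import numpy as np

# Hypothetical flat description of a 1-node, 4-feature observation space
obs_space_description = ["node1_hardware", "node1_software", "node1_file_system", "node1_service"]

# Header mirrors the new write_transaction_to_file() construction
header = ["Timestamp", "Episode", "Step", "Reward"] + ["AS_0"] + obs_space_description

# A transaction row: the flat numpy observation becomes a plain list via tolist()
obs_space = np.array([1, 1, 1, 2])
row = ["2024-01-01_00-00-00", 0, 1, -10.0, "5"] + obs_space.tolist()

assert len(header) == len(row)  # every observation element gets exactly one named column

buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(header)
writer.writerow(row)
```

This also explains why the generic `OSI_X_Y`/`OSN_X_Y` headers disappear from the docs: each CSV column is now named directly by the observation-space description, and only the single observation the agent acted on is logged.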
tests/test_red_random_agent_behaviour.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+from datetime import datetime
+
+from primaite.config.lay_down_config import data_manipulation_config_path
+from primaite.environment.primaite_env import Primaite
+from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
+from tests import TEST_CONFIG_ROOT
+from tests.conftest import _get_temp_session_path
+
+
+def run_generic(env, config_values):
+    """Run against a generic agent."""
+    # Reset the environment at the start of the episode
+    env.reset()
+    for episode in range(0, config_values.num_episodes):
+        for step in range(0, config_values.num_steps):
+            # Send the observation space to the agent to get an action
+            # TEMP - random action for now
+            # action = env.blue_agent_action(obs)
+            # action = env.action_space.sample()
+            action = 0
+
+            # Run the simulation step on the live environment
+            obs, reward, done, info = env.step(action)
+
+            # Break if done is True
+            if done:
+                break
+
+        # Reset the environment at the end of the episode
+        env.reset()
+
+    env.close()
+
+
+def test_random_red_agent_behaviour():
+    """
+    Test that two sessions with a random red agent generate different red instruction sets.
+
+    The session is run twice and the generated red node PoL instructions are compared.
+    """
+    list_of_node_instructions = []
+
+    # RUN TWICE so we can make sure that red agent is randomised
+    for i in range(2):
+        # Take a config path and create an instance of Primaite
+        session_timestamp: datetime = datetime.now()
+        session_path = _get_temp_session_path(session_timestamp)
+
+        timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
+        env = Primaite(
+            training_config_path=TEST_CONFIG_ROOT
+            / "one_node_states_on_off_main_config.yaml",
+            lay_down_config_path=data_manipulation_config_path(),
+            transaction_list=[],
+            session_path=session_path,
+            timestamp_str=timestamp_str,
+        )
+        # set the random_red_agent flag
+        env.training_config.random_red_agent = True
+        training_config = env.training_config
+        training_config.num_steps = env.episode_steps
+
+        run_generic(env, training_config)
+        # add red pol instructions to list
+        list_of_node_instructions.append(env.red_node_pol)
+
+    # compare instructions to make sure that red instructions are truly random
+    for index, instruction in enumerate(list_of_node_instructions):
+        for key in list_of_node_instructions[index].keys():
+            instruction: NodeStateInstructionRed = list_of_node_instructions[index][key]
+            print(f"run {index}")
+            print(f"{key} start step: {instruction.get_start_step()}")
+            print(f"{key} end step: {instruction.get_end_step()}")
+            print(f"{key} target node id: {instruction.get_target_node_id()}")
+            print("")
+
+    assert list_of_node_instructions[0] != list_of_node_instructions[1]
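The test's final assertion compares the two dicts of red PoL instructions; for dicts, `!=` (which is what `__ne__` resolves to) returns `True` as soon as any instruction field differs between the runs. A toy illustration with hypothetical instruction structures (the keys and fields are stand-ins, not PrimAITE's real ones):

```python
# Toy stand-ins for the red PoL instruction sets from two randomised runs
run_1 = {"instruction_0": {"start_step": 3, "end_step": 7, "target_node_id": "1"}}
run_2 = {"instruction_0": {"start_step": 5, "end_step": 9, "target_node_id": "1"}}

# dict.__ne__ and != are equivalent here; != is the idiomatic spelling
assert run_1 != run_2
assert run_1.__ne__(run_2)
```

Note the test is probabilistic: two independent random draws could in principle produce identical instruction sets, in which case this assertion would fail spuriously.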
@@ -16,17 +16,26 @@ def test_rewards_are_being_penalised_at_each_step_function():
     )

     """
-    On different steps (of the 13 in total) these are the following rewards for config_6 which are activated:
-    File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3)
-    Hardware State: onShouldBeOff = -2 (between Steps 4 & 6)
-    Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9)
-    Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12)
-
-    Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26
-    Step Count: 13
+    The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps.
+    On different steps, the laydown config has Pattern of Life (PoL) events which change a state of the node's
+    attributes, for example turning the node's file system state to CORRUPT from its original state GOOD.
+    As a result, the following rewards are activated:
+    File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2)
+    Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5)
+    Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8)
+    Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11)
+
+    Each PoL event lasts for 2 steps, so the agent is penalised twice per event.
+
+    Note: This test inherits from conftest.py, where the PrimAITE environment is run and the blue agent is
+    hard-coded to do NOTHING on every step.
+    We use PoL events to change the node states and show that the agent is penalised on all steps where the
+    live network node differs from the reference network node.
+
+    Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120
+    Step Count: 15

     For the 4 steps where this occurs the average reward is:
-    Average Reward: 2 (26 / 13)
+    Average Reward: -8 (-120 / 15)
     """
-    print("average reward", env.average_reward)
     assert env.average_reward == -8.0
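The arithmetic in the updated docstring can be checked directly: four PoL events, each penalised on two consecutive steps, summed and then averaged over the 15-step episode.

```python
# Per-step penalties from the docstring: two -10 events and two -20 events,
# each lasting 2 steps
penalties = [-10, -10, -10, -10, -20, -20, -20, -20]
total_reward = sum(penalties)        # total over the episode
average_reward = total_reward / 15   # averaged over all 15 steps, not just penalised ones
print(total_reward, average_reward)  # → -120 -8.0
```

The average divides by the full step count (15), which is why the asserted value is -8.0 even though only 8 steps carry a penalty.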