feature\1522:
Create random red agent behaviour.
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Final, Union, Optional
|
from typing import Any, Dict, Final, Optional, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@@ -21,6 +21,9 @@ class TrainingConfig:
|
|||||||
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
||||||
"The Red Agent algo/class to be used."
|
"The Red Agent algo/class to be used."
|
||||||
|
|
||||||
|
red_agent_identifier: str = "RANDOM"
|
||||||
|
"Creates Random Red Agent Attacks"
|
||||||
|
|
||||||
action_type: ActionType = ActionType.ANY
|
action_type: ActionType = ActionType.ANY
|
||||||
"The ActionType to use."
|
"The ActionType to use."
|
||||||
|
|
||||||
@@ -167,8 +170,7 @@ def main_training_config_path() -> Path:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def load(file_path: Union[str, Path],
|
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||||
legacy_file: bool = False) -> TrainingConfig:
|
|
||||||
"""
|
"""
|
||||||
Read in a training config yaml file.
|
Read in a training config yaml file.
|
||||||
|
|
||||||
@@ -213,9 +215,7 @@ def load(file_path: Union[str, Path],
|
|||||||
|
|
||||||
|
|
||||||
def convert_legacy_training_config_dict(
|
def convert_legacy_training_config_dict(
|
||||||
legacy_config_dict: Dict[str, Any],
|
legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY"
|
||||||
num_steps: int = 256,
|
|
||||||
action_type: str = "ANY"
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert a legacy training config dict to the new format.
|
Convert a legacy training config dict to the new format.
|
||||||
@@ -227,10 +227,7 @@ def convert_legacy_training_config_dict(
|
|||||||
don't have action_type values.
|
don't have action_type values.
|
||||||
:return: The converted training config dict.
|
:return: The converted training config dict.
|
||||||
"""
|
"""
|
||||||
config_dict = {
|
config_dict = {"num_steps": num_steps, "action_type": action_type}
|
||||||
"num_steps": num_steps,
|
|
||||||
"action_type": action_type
|
|
||||||
}
|
|
||||||
for legacy_key, value in legacy_config_dict.items():
|
for legacy_key, value in legacy_config_dict.items():
|
||||||
new_key = _get_new_key_from_legacy(legacy_key)
|
new_key = _get_new_key_from_legacy(legacy_key)
|
||||||
if new_key:
|
if new_key:
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from gym import Env, spaces
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from primaite.acl.access_control_list import AccessControlList
|
from primaite.acl.access_control_list import AccessControlList
|
||||||
from primaite.agents.utils import is_valid_acl_action_extra, \
|
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||||
is_valid_node_action
|
|
||||||
from primaite.common.custom_typing import NodeUnion
|
from primaite.common.custom_typing import NodeUnion
|
||||||
from primaite.common.enums import (
|
from primaite.common.enums import (
|
||||||
ActionType,
|
ActionType,
|
||||||
@@ -24,8 +23,9 @@ from primaite.common.enums import (
|
|||||||
NodePOLInitiator,
|
NodePOLInitiator,
|
||||||
NodePOLType,
|
NodePOLType,
|
||||||
NodeType,
|
NodeType,
|
||||||
|
ObservationType,
|
||||||
Priority,
|
Priority,
|
||||||
SoftwareState, ObservationType,
|
SoftwareState,
|
||||||
)
|
)
|
||||||
from primaite.common.service import Service
|
from primaite.common.service import Service
|
||||||
from primaite.config import training_config
|
from primaite.config import training_config
|
||||||
@@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function
|
|||||||
from primaite.links.link import Link
|
from primaite.links.link import Link
|
||||||
from primaite.nodes.active_node import ActiveNode
|
from primaite.nodes.active_node import ActiveNode
|
||||||
from primaite.nodes.node import Node
|
from primaite.nodes.node import Node
|
||||||
from primaite.nodes.node_state_instruction_green import \
|
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||||
NodeStateInstructionGreen
|
|
||||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||||
from primaite.nodes.passive_node import PassiveNode
|
from primaite.nodes.passive_node import PassiveNode
|
||||||
from primaite.nodes.service_node import ServiceNode
|
from primaite.nodes.service_node import ServiceNode
|
||||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||||
from primaite.pol.ier import IER
|
from primaite.pol.ier import IER
|
||||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, \
|
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||||
apply_red_agent_node_pol
|
|
||||||
from primaite.transactions.transaction import Transaction
|
from primaite.transactions.transaction import Transaction
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -177,7 +175,6 @@ class Primaite(Env):
|
|||||||
# It will be initialised later.
|
# It will be initialised later.
|
||||||
self.obs_handler: ObservationsHandler
|
self.obs_handler: ObservationsHandler
|
||||||
|
|
||||||
|
|
||||||
# Open the config file and build the environment laydown
|
# Open the config file and build the environment laydown
|
||||||
with open(self._lay_down_config_path, "r") as file:
|
with open(self._lay_down_config_path, "r") as file:
|
||||||
# Open the config file and build the environment laydown
|
# Open the config file and build the environment laydown
|
||||||
@@ -238,7 +235,9 @@ class Primaite(Env):
|
|||||||
self.action_dict = self.create_node_and_acl_action_dict()
|
self.action_dict = self.create_node_and_acl_action_dict()
|
||||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||||
else:
|
else:
|
||||||
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
_LOGGER.info(
|
||||||
|
f"Invalid action type selected: {self.training_config.action_type}"
|
||||||
|
)
|
||||||
# Set up a csv to store the results of the training
|
# Set up a csv to store the results of the training
|
||||||
try:
|
try:
|
||||||
header = ["Episode", "Average Reward"]
|
header = ["Episode", "Average Reward"]
|
||||||
@@ -275,6 +274,10 @@ class Primaite(Env):
|
|||||||
# Does this for both live and reference nodes
|
# Does this for both live and reference nodes
|
||||||
self.reset_environment()
|
self.reset_environment()
|
||||||
|
|
||||||
|
# Create a random red agent to use for this episode
|
||||||
|
if self.training_config.red_agent_identifier == "RANDOM":
|
||||||
|
self.create_random_red_agent()
|
||||||
|
|
||||||
# Reset counters and totals
|
# Reset counters and totals
|
||||||
self.total_reward = 0
|
self.total_reward = 0
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
@@ -379,7 +382,7 @@ class Primaite(Env):
|
|||||||
self.step_count,
|
self.step_count,
|
||||||
self.training_config,
|
self.training_config,
|
||||||
)
|
)
|
||||||
#print(f" Step {self.step_count} Reward: {str(reward)}")
|
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||||
self.total_reward += reward
|
self.total_reward += reward
|
||||||
if self.step_count == self.episode_steps:
|
if self.step_count == self.episode_steps:
|
||||||
self.average_reward = self.total_reward / self.step_count
|
self.average_reward = self.total_reward / self.step_count
|
||||||
@@ -1033,7 +1036,6 @@ class Primaite(Env):
|
|||||||
"""
|
"""
|
||||||
self.observation_type = ObservationType[observation_info["type"]]
|
self.observation_type = ObservationType[observation_info["type"]]
|
||||||
|
|
||||||
|
|
||||||
def get_action_info(self, action_info):
|
def get_action_info(self, action_info):
|
||||||
"""
|
"""
|
||||||
Extracts action_info.
|
Extracts action_info.
|
||||||
@@ -1216,3 +1218,152 @@ class Primaite(Env):
|
|||||||
# Combine the Node dict and ACL dict
|
# Combine the Node dict and ACL dict
|
||||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||||
return combined_action_dict
|
return combined_action_dict
|
||||||
|
|
||||||
|
def create_random_red_agent(self):
|
||||||
|
"""Decide on random red agent for the episode to be called in env.reset()."""
|
||||||
|
|
||||||
|
# Reset the current red iers and red node pol
|
||||||
|
self.red_iers = {}
|
||||||
|
self.red_node_pol = {}
|
||||||
|
|
||||||
|
# Decide how many nodes become compromised
|
||||||
|
node_list = list(self.nodes.values())
|
||||||
|
computers = [node for node in node_list if node.node_type == NodeType.COMPUTER]
|
||||||
|
max_num_nodes_compromised = len(
|
||||||
|
computers
|
||||||
|
) # only computers can become compromised
|
||||||
|
# random select between 1 and max_num_nodes_compromised
|
||||||
|
num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1)
|
||||||
|
|
||||||
|
# Decide which of the nodes to compromise
|
||||||
|
nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise)
|
||||||
|
|
||||||
|
# For each of the nodes to be compromised decide which step they become compromised
|
||||||
|
max_step_compromised = (
|
||||||
|
self.episode_steps // 2
|
||||||
|
) # always compromise in first half of episode
|
||||||
|
|
||||||
|
# Bandwidth for all links
|
||||||
|
bandwidths = [i.get_bandwidth() for i in list(self.links.values())]
|
||||||
|
servers = [node for node in node_list if node.node_type == NodeType.SERVER]
|
||||||
|
|
||||||
|
for n, node in enumerate(nodes_to_be_compromised):
|
||||||
|
# 1: Use Node PoL to set node to compromised
|
||||||
|
|
||||||
|
_id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate
|
||||||
|
_start_step = np.random.randint(
|
||||||
|
2, max_step_compromised + 1
|
||||||
|
) # step compromised
|
||||||
|
_end_step = _start_step # Become compromised on 1 step
|
||||||
|
_target_node_id = node.node_id
|
||||||
|
_pol_initiator = "DIRECT"
|
||||||
|
_pol_type = NodePOLType["SERVICE"] # All computers are service nodes
|
||||||
|
pol_service_name = np.random.choice(
|
||||||
|
list(node.get_services().keys())
|
||||||
|
) # Random service may wish to change this, currently always TCP)
|
||||||
|
pol_protocol = pol_protocol
|
||||||
|
_pol_state = SoftwareState.COMPROMISED
|
||||||
|
is_entry_node = True # Assumes all computers in network are entry nodes
|
||||||
|
_pol_source_node_id = _pol_source_node_id
|
||||||
|
_pol_source_node_service = _pol_source_node_service
|
||||||
|
_pol_source_node_service_state = _pol_source_node_service_state
|
||||||
|
red_pol = NodeStateInstructionRed(
|
||||||
|
_id,
|
||||||
|
_start_step,
|
||||||
|
_end_step,
|
||||||
|
_target_node_id,
|
||||||
|
_pol_initiator,
|
||||||
|
_pol_type,
|
||||||
|
pol_protocol,
|
||||||
|
_pol_state,
|
||||||
|
_pol_source_node_id,
|
||||||
|
_pol_source_node_service,
|
||||||
|
_pol_source_node_service_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.red_node_pol[_id] = red_pol
|
||||||
|
|
||||||
|
# 2: Launch the attack from compromised node - set the IER
|
||||||
|
|
||||||
|
ier_id = str(2000 + n)
|
||||||
|
# Launch the attack after node is compromised, and not right at the end of the episode
|
||||||
|
ier_start_step = np.random.randint(
|
||||||
|
_start_step + 2, int(self.episode_steps * 0.8)
|
||||||
|
)
|
||||||
|
ier_end_step = self.episode_steps
|
||||||
|
ier_source_node_id = node.get_id()
|
||||||
|
# Randomise the load, as a percentage of a random link bandwith
|
||||||
|
ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice(
|
||||||
|
bandwidths
|
||||||
|
)
|
||||||
|
ier_protocol = pol_service_name # Same protocol as compromised node
|
||||||
|
ier_service = node.get_services()[
|
||||||
|
pol_service_name
|
||||||
|
] # same service as defined in the pol
|
||||||
|
ier_port = ier_service.get_port()
|
||||||
|
ier_mission_criticality = (
|
||||||
|
0 # Red IER will never be important to green agent success
|
||||||
|
)
|
||||||
|
# We choose a node to attack based on the first that applies:
|
||||||
|
# a. Green IERs, select dest node of the red ier based on dest node of green IER
|
||||||
|
# b. Attack a random server that doesn't have a DENY acl rule in default config
|
||||||
|
# c. Attack a random server
|
||||||
|
possible_ier_destinations = [
|
||||||
|
ier.get_dest_node_id()
|
||||||
|
for ier in list(self.green_iers.values())
|
||||||
|
if ier.get_source_node_id() == node.get_id()
|
||||||
|
]
|
||||||
|
if len(possible_ier_destinations) < 1:
|
||||||
|
for server in servers:
|
||||||
|
if not self.acl.is_blocked(
|
||||||
|
node.get_ip_address(),
|
||||||
|
server.ip_address,
|
||||||
|
ier_service,
|
||||||
|
ier_port,
|
||||||
|
):
|
||||||
|
possible_ier_destinations.append(server.node_id)
|
||||||
|
if len(possible_ier_destinations) < 1:
|
||||||
|
# If still none found choose from all servers
|
||||||
|
possible_ier_destinations = [server.node_id for server in servers]
|
||||||
|
ier_dest = np.random.choice(possible_ier_destinations)
|
||||||
|
self.red_iers[ier_id] = IER(
|
||||||
|
ier_id,
|
||||||
|
ier_start_step,
|
||||||
|
ier_end_step,
|
||||||
|
ier_load,
|
||||||
|
ier_protocol,
|
||||||
|
ier_port,
|
||||||
|
ier_source_node_id,
|
||||||
|
ier_dest,
|
||||||
|
ier_mission_criticality,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3: Make sure the targetted node can be set to overwhelmed - with node pol
|
||||||
|
# TODO remove duplicate red pol for same targetted service - must take into account start step
|
||||||
|
|
||||||
|
o_pol_id = str(3000 + n)
|
||||||
|
o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched
|
||||||
|
o_pol_end_step = (
|
||||||
|
self.episode_steps
|
||||||
|
) # Can become compromised at any timestep after start
|
||||||
|
o_pol_node_id = ier_dest # Node effected is the one targetted by the IER
|
||||||
|
o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes
|
||||||
|
o_pol_service_name = (
|
||||||
|
ier_protocol # Same protocol/service as the IER uses to attack
|
||||||
|
)
|
||||||
|
o_pol_new_state = SoftwareState["OVERWHELMED"]
|
||||||
|
o_pol_entry_node = False # Assumes servers are not entry nodes
|
||||||
|
o_red_pol = NodeStateInstructionRed(
|
||||||
|
_id,
|
||||||
|
_start_step,
|
||||||
|
_end_step,
|
||||||
|
_target_node_id,
|
||||||
|
_pol_initiator,
|
||||||
|
_pol_type,
|
||||||
|
pol_protocol,
|
||||||
|
_pol_state,
|
||||||
|
_pol_source_node_id,
|
||||||
|
_pol_source_node_service,
|
||||||
|
_pol_source_node_service_state,
|
||||||
|
)
|
||||||
|
self.red_node_pol[o_pol_id] = o_red_pol
|
||||||
|
|||||||
@@ -137,3 +137,20 @@ class NodeStateInstructionRed(object):
|
|||||||
The source node service state
|
The source node service state
|
||||||
"""
|
"""
|
||||||
return self.source_node_service_state
|
return self.source_node_service_state
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"{self.__class__.__name__}("
|
||||||
|
f"id={self.id}, "
|
||||||
|
f"start_step={self.start_step}, "
|
||||||
|
f"end_step={self.end_step}, "
|
||||||
|
f"target_node_id={self.target_node_id}, "
|
||||||
|
f"initiator={self.initiator}, "
|
||||||
|
f"pol_type={self.pol_type}, "
|
||||||
|
f"service_name={self.service_name}, "
|
||||||
|
f"state={self.state}, "
|
||||||
|
f"source_node_id={self.source_node_id}, "
|
||||||
|
f"source_node_service={self.source_node_service}, "
|
||||||
|
f"source_node_service_state={self.source_node_service_state}"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
96
tests/config/random_agent_main_config.yaml
Normal file
96
tests/config/random_agent_main_config.yaml
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agent_identifier: GENERIC
|
||||||
|
#
|
||||||
|
red_agent_identifier: RANDOM
|
||||||
|
# Sets How the Action Space is defined:
|
||||||
|
# "NODE"
|
||||||
|
# "ACL"
|
||||||
|
# "ANY" node and acl actions
|
||||||
|
action_type: ANY
|
||||||
|
# Number of episodes to run per session
|
||||||
|
num_episodes: 1
|
||||||
|
# Number of time_steps per episode
|
||||||
|
num_steps: 5
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
time_delay: 1
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
session_type: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
load_agent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observation_space_high_value: 1_000_000_000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
all_ok: 0
|
||||||
|
# Node Hardware State
|
||||||
|
off_should_be_on: -10
|
||||||
|
off_should_be_resetting: -5
|
||||||
|
on_should_be_off: -2
|
||||||
|
on_should_be_resetting: -5
|
||||||
|
resetting_should_be_on: -5
|
||||||
|
resetting_should_be_off: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
good_should_be_patching: 2
|
||||||
|
good_should_be_compromised: 5
|
||||||
|
good_should_be_overwhelmed: 5
|
||||||
|
patching_should_be_good: -5
|
||||||
|
patching_should_be_compromised: 2
|
||||||
|
patching_should_be_overwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromised_should_be_good: -20
|
||||||
|
compromised_should_be_patching: -20
|
||||||
|
compromised_should_be_overwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmed_should_be_good: -20
|
||||||
|
overwhelmed_should_be_patching: -20
|
||||||
|
overwhelmed_should_be_compromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
good_should_be_repairing: 2
|
||||||
|
good_should_be_restoring: 2
|
||||||
|
good_should_be_corrupt: 5
|
||||||
|
good_should_be_destroyed: 10
|
||||||
|
repairing_should_be_good: -5
|
||||||
|
repairing_should_be_restoring: 2
|
||||||
|
repairing_should_be_corrupt: 2
|
||||||
|
repairing_should_be_destroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoring_should_be_good: -10
|
||||||
|
restoring_should_be_repairing: -2
|
||||||
|
restoring_should_be_corrupt: 1
|
||||||
|
restoring_should_be_destroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corrupt_should_be_good: -10
|
||||||
|
corrupt_should_be_repairing: -10
|
||||||
|
corrupt_should_be_restoring: -10
|
||||||
|
corrupt_should_be_destroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyed_should_be_good: -20
|
||||||
|
destroyed_should_be_repairing: -20
|
||||||
|
destroyed_should_be_restoring: -20
|
||||||
|
destroyed_should_be_corrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
red_ier_running: -5
|
||||||
|
green_ier_blocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
os_patching_duration: 5 # The time taken to patch the OS
|
||||||
|
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||||
|
service_patching_duration: 5 # The time taken to patch a service
|
||||||
|
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||||
|
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||||
|
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||||
74
tests/test_red_random_agent_behaviour.py
Normal file
74
tests/test_red_random_agent_behaviour.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from datetime import time, datetime
|
||||||
|
|
||||||
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
from tests import TEST_CONFIG_ROOT
|
||||||
|
from tests.conftest import _get_temp_session_path
|
||||||
|
|
||||||
|
|
||||||
|
def run_generic(env, config_values):
|
||||||
|
"""Run against a generic agent."""
|
||||||
|
# Reset the environment at the start of the episode
|
||||||
|
env.reset()
|
||||||
|
for episode in range(0, config_values.num_episodes):
|
||||||
|
for step in range(0, config_values.num_steps):
|
||||||
|
# Send the observation space to the agent to get an action
|
||||||
|
# TEMP - random action for now
|
||||||
|
# action = env.blue_agent_action(obs)
|
||||||
|
# action = env.action_space.sample()
|
||||||
|
action = 0
|
||||||
|
|
||||||
|
# Run the simulation step on the live environment
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
# Break if done is True
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Introduce a delay between steps
|
||||||
|
time.sleep(config_values.time_delay / 1000)
|
||||||
|
|
||||||
|
# Reset the environment at the end of the episode
|
||||||
|
env.reset()
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_red_agent_behaviour():
|
||||||
|
"""
|
||||||
|
Test that hardware state is penalised at each step.
|
||||||
|
|
||||||
|
When the initial state is OFF compared to reference state which is ON.
|
||||||
|
"""
|
||||||
|
list_of_node_instructions = []
|
||||||
|
for i in range(2):
|
||||||
|
|
||||||
|
"""Takes a config path and returns the created instance of Primaite."""
|
||||||
|
session_timestamp: datetime = datetime.now()
|
||||||
|
session_path = _get_temp_session_path(session_timestamp)
|
||||||
|
|
||||||
|
timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
env = Primaite(
|
||||||
|
training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml",
|
||||||
|
transaction_list=[],
|
||||||
|
session_path=session_path,
|
||||||
|
timestamp_str=timestamp_str,
|
||||||
|
)
|
||||||
|
training_config = env.training_config
|
||||||
|
training_config.num_steps = env.episode_steps
|
||||||
|
|
||||||
|
# TOOD: This needs t be refactored to happen outside. Should be part of
|
||||||
|
# a main Session class.
|
||||||
|
if training_config.agent_identifier == "GENERIC":
|
||||||
|
run_generic(env, training_config)
|
||||||
|
all_red_actions = env.red_node_pol
|
||||||
|
list_of_node_instructions.append(all_red_actions)
|
||||||
|
|
||||||
|
# assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1]))
|
||||||
|
print(list_of_node_instructions[0]["1"].get_start_step())
|
||||||
|
print(list_of_node_instructions[0]["1"].get_end_step())
|
||||||
|
print(list_of_node_instructions[0]["1"].get_target_node_id())
|
||||||
|
print(list_of_node_instructions[1]["1"].get_start_step())
|
||||||
|
print(list_of_node_instructions[1]["1"].get_end_step())
|
||||||
|
print(list_of_node_instructions[1]["1"].get_target_node_id())
|
||||||
|
assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1])
|
||||||
Reference in New Issue
Block a user