#917 - Fixed the RLlib integration

- Dropped support for overriding num_episodes and num_steps at the agent level. It is not needed and would add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(
observation_shape, dtype=self._DATA_TYPE
)
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
@@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent):
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][
2
] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][3] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
@@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
self.current_observation[item_index][service_index] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
@@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
@@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(
service
).value
service_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
@@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels]
* self.env.num_links
* self._entries_per_link
)
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
@@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [
protocol.get_load() for protocol in link.protocol_list
]
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
@@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
obs.append(int(traffic_level))

View File

@@ -12,13 +12,11 @@ from matplotlib import pyplot as plt
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import (
is_valid_acl_action_extra,
is_valid_node_action,
)
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import (
apply_red_agent_iers,
apply_red_agent_node_pol,
)
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
@@ -85,9 +78,7 @@ class Primaite(Env):
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(
training_config_path
)
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
@@ -238,25 +229,22 @@ class Primaite(Env):
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.error(
f"Invalid action type selected: {self.training_config.action_type}"
)
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
self.episode_av_reward_writer = SessionOutputWriter(
self, transaction_writer=False, learning_session=True
)
self.transaction_writer = SessionOutputWriter(
self, transaction_writer=True, learning_session=True
)
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
@property
def actual_episode_count(self) -> int:
    """
    Return the episode number to report, corrected for RLlib.

    When the configured agent framework is RLlib, the value is ``episode_count - 1``.
    For every other framework it is ``episode_count`` unchanged.

    :return: The adjusted episode count.
    :rtype: int
    """
    # NOTE(review): presumably RLlib triggers one extra reset before real training starts,
    # which is why the raw counter runs one ahead - TODO confirm.
    running_under_rllib = self.training_config.agent_framework is AgentFramework.RLLIB
    return self.episode_count - 1 if running_under_rllib else self.episode_count
def set_as_eval(self):
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(
self, transaction_writer=False, learning_session=False
)
self.transaction_writer = SessionOutputWriter(
self, transaction_writer=True, learning_session=False
)
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
@@ -268,8 +256,8 @@ class Primaite(Env):
Returns:
Environment observation space (reset)
"""
if self.episode_count > 0:
csv_data = self.episode_count, self.average_reward
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
self.episode_count += 1
@@ -291,6 +279,7 @@ class Primaite(Env):
# Update observations space and return
self.update_environent_obs()
return self.env_obs
def step(self, action):
@@ -319,9 +308,7 @@ class Primaite(Env):
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
self.agent_identifier, self.episode_count, self.step_count
)
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
# Load the initial observation space into the transaction
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
# Load the action space into the transaction
@@ -350,9 +337,7 @@ class Primaite(Env):
self.nodes_post_pol = copy.deepcopy(self.nodes)
self.links_post_pol = copy.deepcopy(self.links)
# Reference
apply_node_pol(
self.nodes_reference, self.node_pol, self.step_count
) # Node PoL
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
apply_iers(
self.network_reference,
self.nodes_reference,
@@ -371,9 +356,7 @@ class Primaite(Env):
self.acl,
self.step_count,
)
apply_red_agent_node_pol(
self.nodes, self.red_iers, self.red_node_pol, self.step_count
)
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
# Take snapshots of nodes and links
self.nodes_post_red = copy.deepcopy(self.nodes)
self.links_post_red = copy.deepcopy(self.links)
@@ -389,11 +372,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
_LOGGER.debug(
f"Episode: {self.episode_count}, "
f"Step {self.step_count}, "
f"Reward: {reward}"
)
_LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -401,10 +380,7 @@ class Primaite(Env):
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
_LOGGER.info(
f"Episode: {self.episode_count}, "
f"Average Reward: {self.average_reward}"
)
_LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
# Load the reward into the transaction
transaction.reward = reward
@@ -417,11 +393,21 @@ class Primaite(Env):
transaction.obs_space_post = copy.deepcopy(self.env_obs)
# Write transaction to file
self.transaction_writer.write(transaction)
if self.actual_episode_count > 0:
self.transaction_writer.write(transaction)
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
    """
    Override the parent ``close`` and close the session output writers.

    Calls the parent class ``close`` first, then closes ``transaction_writer``
    followed by ``episode_av_reward_writer``.
    """
    # NOTE(review): a ``can_finish`` guard was sketched here but is commented out, so the
    # writers are closed unconditionally whenever this method is called.
    # if self.can_finish:
    super().close()
    self.transaction_writer.close()
    self.episode_av_reward_writer.close()
def init_acl(self):
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
@@ -431,12 +417,7 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
_LOGGER.debug(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
_LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
"""
@@ -450,13 +431,9 @@ class Primaite(Env):
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
_LOGGER.error("Invalid action type found")
@@ -541,10 +518,7 @@ class Primaite(Env):
elif property_action == 2:
# Repair
# You cannot repair a destroyed file system - it needs restoring
if (
node.file_system_state_actual
!= FileSystemState.DESTROYED
):
if node.file_system_state_actual != FileSystemState.DESTROYED:
node.set_file_system_state(FileSystemState.REPAIRING)
elif property_action == 3:
# Restore
@@ -587,9 +561,7 @@ class Primaite(Env):
acl_rule_source = "ANY"
else:
node = list(self.nodes.values())[action_source_ip - 1]
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
acl_rule_source = node.ip_address
else:
return
@@ -598,9 +570,7 @@ class Primaite(Env):
acl_rule_destination = "ANY"
else:
node = list(self.nodes.values())[action_destination_ip - 1]
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
acl_rule_destination = node.ip_address
else:
return
@@ -685,9 +655,7 @@ class Primaite(Env):
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
self.obs_handler = ObservationsHandler.from_config(
self, self.obs_config
)
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation
@@ -794,9 +762,7 @@ class Primaite(Env):
service_protocol = service["name"]
service_port = service["port"]
service_state = SoftwareState[service["state"]]
node.add_service(
Service(service_protocol, service_port, service_state)
)
node.add_service(Service(service_protocol, service_port, service_state))
else:
# Bad formatting
pass
@@ -849,9 +815,7 @@ class Primaite(Env):
dest_node_ref: Node = self.nodes_reference[link_destination]
# Add link to network (reference)
self.network_reference.add_edge(
source_node_ref, dest_node_ref, id=link_name
)
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
# Add link to link dictionary (reference)
self.links_reference[link_name] = Link(
@@ -1126,9 +1090,7 @@ class Primaite(Env):
# All nodes have these parameters
node_id = item["node_id"]
node_class = item["node_class"]
node_hardware_state: HardwareState = HardwareState[
item["hardware_state"]
]
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
node: NodeUnion = self.nodes[node_id]
node_ref = self.nodes_reference[node_id]
@@ -1249,11 +1211,7 @@ class Primaite(Env):
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}

View File

@@ -41,29 +41,19 @@ def calculate_reward_function(
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(
final_node, ServiceNode
):
reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
@@ -82,10 +72,7 @@ def calculate_reward_function(
if step_count >= start_step and step_count <= stop_step:
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
)
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
if live_blocked and not reference_blocked:
reward_value += ier_reward
@@ -107,9 +94,7 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(
final_node, initial_node, reference_node, config_values
):
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the hardware state of a node.
@@ -158,9 +143,7 @@ def score_node_operating_state(
return score
def score_node_os_state(
final_node, initial_node, reference_node, config_values
):
def score_node_os_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the Software State of a node.
@@ -211,9 +194,7 @@ def score_node_os_state(
return score
def score_node_service_state(
final_node, initial_node, reference_node, config_values
):
def score_node_service_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the service state(s) of a node.
@@ -285,9 +266,7 @@ def score_node_service_state(
return score
def score_node_file_system(
final_node, initial_node, reference_node, config_values
):
def score_node_file_system(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the file system state of a node.