#917 - Fixed the RLlib integration
- Dropped support for overriding num_episodes and num_steps at the agent level. It is not needed and would add complexity when overriding values and writing output files.
This commit is contained in:
@@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(
|
||||
observation_shape, dtype=self._DATA_TYPE
|
||||
)
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][
|
||||
2
|
||||
] = node.software_state.value
|
||||
self.current_observation[item_index][
|
||||
3
|
||||
] = node.file_system_state_observed.value
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
@@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.env.services_list:
|
||||
if node.has_service(service):
|
||||
self.current_observation[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
self.current_observation[item_index][service_index] = node.get_service_state(service).value
|
||||
else:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
@@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.current_observation[item_index][
|
||||
protocol_index + 4
|
||||
] = protocol.get_load()
|
||||
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
@@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(
|
||||
service
|
||||
).value
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
@@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
self._entries_per_link = self.env.num_services
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = (
|
||||
[self._quantisation_levels]
|
||||
* self.env.num_links
|
||||
* self._entries_per_link
|
||||
)
|
||||
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
@@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [
|
||||
protocol.get_load() for protocol in link.protocol_list
|
||||
]
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
@@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
elif load >= bandwidth:
|
||||
traffic_level = self._quantisation_levels - 1
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (
|
||||
1 / (self._quantisation_levels - 2)
|
||||
) + 1
|
||||
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@ from matplotlib import pyplot as plt
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import (
|
||||
is_valid_acl_action_extra,
|
||||
is_valid_node_action,
|
||||
)
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import (
|
||||
apply_red_agent_iers,
|
||||
apply_red_agent_node_pol,
|
||||
)
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
@@ -85,9 +78,7 @@ class Primaite(Env):
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(
|
||||
training_config_path
|
||||
)
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
@@ -238,25 +229,22 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.error(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
|
||||
|
||||
@property
def actual_episode_count(self) -> int:
    """Return the episode count, shifted down by one when running under RLlib.

    For every other agent framework the raw ``episode_count`` is returned
    unchanged.
    """
    running_rllib = self.training_config.agent_framework is AgentFramework.RLLIB
    return self.episode_count - 1 if running_rllib else self.episode_count
|
||||
|
||||
def set_as_eval(self):
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=False
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=False
|
||||
)
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
|
||||
self.episode_count = 0
|
||||
self.step_count = 0
|
||||
self.total_step_count = 0
|
||||
@@ -268,8 +256,8 @@ class Primaite(Env):
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
"""
|
||||
if self.episode_count > 0:
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
if self.actual_episode_count > 0:
|
||||
csv_data = self.actual_episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
self.episode_count += 1
|
||||
@@ -291,6 +279,7 @@ class Primaite(Env):
|
||||
|
||||
# Update observations space and return
|
||||
self.update_environent_obs()
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
@@ -319,9 +308,7 @@ class Primaite(Env):
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
|
||||
# Load the action space into the transaction
|
||||
@@ -350,9 +337,7 @@ class Primaite(Env):
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(
|
||||
self.nodes_reference, self.node_pol, self.step_count
|
||||
) # Node PoL
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
@@ -371,9 +356,7 @@ class Primaite(Env):
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_node_pol(
|
||||
self.nodes, self.red_iers, self.red_node_pol, self.step_count
|
||||
)
|
||||
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_red = copy.deepcopy(self.nodes)
|
||||
self.links_post_red = copy.deepcopy(self.links)
|
||||
@@ -389,11 +372,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
_LOGGER.debug(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Step {self.step_count}, "
|
||||
f"Reward: {reward}"
|
||||
)
|
||||
_LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -401,10 +380,7 @@ class Primaite(Env):
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
_LOGGER.info(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Average Reward: {self.average_reward}"
|
||||
)
|
||||
_LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
|
||||
# Load the reward into the transaction
|
||||
transaction.reward = reward
|
||||
|
||||
@@ -417,11 +393,21 @@ class Primaite(Env):
|
||||
transaction.obs_space_post = copy.deepcopy(self.env_obs)
|
||||
|
||||
# Write transaction to file
|
||||
self.transaction_writer.write(transaction)
|
||||
if self.actual_episode_count > 0:
|
||||
self.transaction_writer.write(transaction)
|
||||
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
    """Override the parent ``close`` and close both session output writers.

    The parent ``close`` is called first, then the transaction writer and
    the episode average-reward writer are closed. The writers are closed
    unconditionally, and each close is attempted even if an earlier step
    raises, so a failure part-way through cannot leave an output writer
    open. Any exception raised still propagates to the caller.
    """
    try:
        super().close()
    finally:
        # Nested so the second writer is closed even if the first one fails.
        try:
            self.transaction_writer.close()
        finally:
            self.episode_av_reward_writer.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
@@ -431,12 +417,7 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
_LOGGER.debug(
|
||||
" Protocol: "
|
||||
+ protocol.get_name().name
|
||||
+ ", Load: "
|
||||
+ str(protocol.get_load())
|
||||
)
|
||||
_LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""
|
||||
@@ -450,13 +431,9 @@ class Primaite(Env):
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
_LOGGER.error("Invalid action type found")
|
||||
@@ -541,10 +518,7 @@ class Primaite(Env):
|
||||
elif property_action == 2:
|
||||
# Repair
|
||||
# You cannot repair a destroyed file system - it needs restoring
|
||||
if (
|
||||
node.file_system_state_actual
|
||||
!= FileSystemState.DESTROYED
|
||||
):
|
||||
if node.file_system_state_actual != FileSystemState.DESTROYED:
|
||||
node.set_file_system_state(FileSystemState.REPAIRING)
|
||||
elif property_action == 3:
|
||||
# Restore
|
||||
@@ -587,9 +561,7 @@ class Primaite(Env):
|
||||
acl_rule_source = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_source_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
acl_rule_source = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -598,9 +570,7 @@ class Primaite(Env):
|
||||
acl_rule_destination = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_destination_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
acl_rule_destination = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -685,9 +655,7 @@ class Primaite(Env):
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
"""
|
||||
self.obs_handler = ObservationsHandler.from_config(
|
||||
self, self.obs_config
|
||||
)
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
@@ -794,9 +762,7 @@ class Primaite(Env):
|
||||
service_protocol = service["name"]
|
||||
service_port = service["port"]
|
||||
service_state = SoftwareState[service["state"]]
|
||||
node.add_service(
|
||||
Service(service_protocol, service_port, service_state)
|
||||
)
|
||||
node.add_service(Service(service_protocol, service_port, service_state))
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
@@ -849,9 +815,7 @@ class Primaite(Env):
|
||||
dest_node_ref: Node = self.nodes_reference[link_destination]
|
||||
|
||||
# Add link to network (reference)
|
||||
self.network_reference.add_edge(
|
||||
source_node_ref, dest_node_ref, id=link_name
|
||||
)
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(
|
||||
@@ -1126,9 +1090,7 @@ class Primaite(Env):
|
||||
# All nodes have these parameters
|
||||
node_id = item["node_id"]
|
||||
node_class = item["node_class"]
|
||||
node_hardware_state: HardwareState = HardwareState[
|
||||
item["hardware_state"]
|
||||
]
|
||||
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
|
||||
|
||||
node: NodeUnion = self.nodes[node_id]
|
||||
node_ref = self.nodes_reference[node_id]
|
||||
@@ -1249,11 +1211,7 @@ class Primaite(Env):
|
||||
|
||||
# Change node keys to not overlap with acl keys
|
||||
# Only 1 nothing action (key 0) is required, remove the other
|
||||
new_node_action_dict = {
|
||||
k + len(acl_action_dict) - 1: v
|
||||
for k, v in node_action_dict.items()
|
||||
if k != 0
|
||||
}
|
||||
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
|
||||
@@ -41,29 +41,19 @@ def calculate_reward_function(
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
# Hardware State
|
||||
reward_value += score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Software State
|
||||
if isinstance(final_node, ActiveNode) or isinstance(
|
||||
final_node, ServiceNode
|
||||
):
|
||||
reward_value += score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Service State
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
@@ -82,10 +72,7 @@ def calculate_reward_function(
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
)
|
||||
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
reward_value += ier_reward
|
||||
@@ -107,9 +94,7 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -158,9 +143,7 @@ def score_node_operating_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -211,9 +194,7 @@ def score_node_os_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -285,9 +266,7 @@ def score_node_service_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user