#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -5,7 +5,7 @@ import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Optional, Union
from typing import Dict, Final, Union
from uuid import uuid4
import yaml
@@ -51,16 +51,12 @@ class AgentSessionABC(ABC):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(
self._training_config_path
)
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(
self._lay_down_config_path
)
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.output_verbose_level = self._training_config.output_verbose_level
self._env: Primaite
@@ -132,9 +128,7 @@ class AgentSessionABC(ABC):
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(
json_serializable=True
),
"training_config": self._training_config.to_dict(json_serializable=True),
"lay_down_config": self._lay_down_config,
},
}
@@ -161,19 +155,11 @@ class AgentSessionABC(ABC):
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["learning"][
"total_time_steps"
] = self._env.total_step_count # noqa
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["evaluation"][
"total_time_steps"
] = self._env.total_step_count # noqa
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -184,12 +170,9 @@ class AgentSessionABC(ABC):
@abstractmethod
def _setup(self):
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment "
f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(
f"The output directory for this session is: {self.session_path}"
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@@ -201,17 +184,11 @@ class AgentSessionABC(ABC):
@abstractmethod
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
@@ -225,17 +202,11 @@ class AgentSessionABC(ABC):
@abstractmethod
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
@@ -281,9 +252,7 @@ class AgentSessionABC(ABC):
else:
# Session path does not exist
msg = (
f"Failed to load PrimAITE Session, path does not exist: {path}"
)
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC):
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC):
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
if not time_steps:
time_steps = self._training_config.num_steps
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
if not episodes:
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation

View File

@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
if (
self._training_config.hard_coded_agent_view
== HardCodedAgentView.BASIC
):
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(
source_node_address, dest_node_address, protocol, port
):
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[
protocol - 1
] # -1 as don't have to account for ANY in list of services
protocol = services_list[protocol - 1] # -1 as don't have to account for ANY in list of services
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(
get_node_of_ip(action_source_ip, self._env.nodes)
)
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(
get_node_of_ip(
action_destination_ip, self._env.nodes
)
)
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name)
+ 1
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(
self._env.green_iers, self._env.acl, self._env.nodes
)
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(
ier, self._env.acl, self._env.nodes
)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(
service_name_to_check
)
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the source node in the ier
service_state = s[service_id_to_check][node_id_to_check - 1]
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len(
[i for i in o if i != "NONE"]
) # number of nodes (not links)
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
comprimised_states = [
n for n, i in enumerate(service_states) if i == "COMPROMISED"
]
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(comprimised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = (
np.random.choice(comprimised_states) + 1
) # +1 as 0 would be any
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(
list(range(1, number_of_nodes + 1)) + ["ANY"]
)
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip)
if action_destination_ip != "ANY"
else action_destination_ip
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block

View File

@@ -1,9 +1,5 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_action_node_enum,
transform_change_obs_readable,
)
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
class HardCodedNodeAgent(HardCodedAgentSessionABC):
@@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
if os_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = (
"ON" # Why reset it when we can just turn it on
)
action_service_index = (
0 # does nothing isn't relevant for operating state
)
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # does nothing isn't relevant for operating state
action = [
action_node_id,
action_node_property,

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
from typing import Union
import tensorflow as tf
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
@@ -14,11 +13,7 @@ from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import (
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
)
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
@@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = (
f"Expected RLLIB agent_framework, "
f"got {self._training_config.agent_framework}"
)
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}"
)
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: PPOConfig
@@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result[
"episodes_total"
]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"
]
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC):
),
)
self._agent_config.training(
train_batch_size=self._training_config.num_steps
)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
@@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC):
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.session_path)
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
# Temporarily override train_batch_size and horizon
if time_steps:
self._agent_config.train_batch_size = time_steps
self._agent_config.horizon = time_steps
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
if not episodes:
episodes = self._training_config.num_episodes
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
if (
self._training_config.deep_learning_framework
!= DeepLearningFramework.TORCH
):
policy = self._agent.get_policy()
tf.compat.v1.summary.FileWriter(
self.session_path / "ray_results", policy.get_session().graph
)
super().learn()
self._agent.stop()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union
from typing import Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = (
f"Expected SB3 agent_framework, "
f"got {self._training_config.agent_framework}"
)
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier}"
)
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
_LOGGER.error(msg)
raise ValueError(msg)
@@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC):
self._env,
verbose=self.output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path,
tensorboard_log=str(self._tensorboard_log_path),
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
checkpoint_path = (
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
)
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
@@ -85,58 +75,37 @@ class SB3Agent(AgentSessionABC):
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self.is_eval = False
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self.close()
self._env.reset()
self._env.close()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
@@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC):
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for "
f"{episodes} episodes @ {time_steps} time steps..."
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(
obs, deterministic=deterministic
)
action, _states = self._agent.predict(obs, deterministic=deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
@classmethod

View File

@@ -1,9 +1,5 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_action_acl_enum,
transform_action_node_enum,
)
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
class RandomAgent(HardCodedAgentSessionABC):

View File

@@ -24,9 +24,7 @@ def transform_action_node_readable(action):
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -117,11 +115,7 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
@@ -173,9 +167,7 @@ def transform_change_obs_readable(obs):
for service in range(3, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
]
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
new_obs.append(service_states)
return new_obs
@@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
return new_obs
def describe_obs_change(
obs1, obs2, num_nodes=10, num_links=10, num_services=1
):
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""
Return string describing change between two observations.
@@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link):
TODO: Typehint params and return.
"""
# Indexes where a change has occurred, not including 0th index
index_changed = [
i for i in range(1, len(obs_change)) if obs_change[i] != -1
]
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [
NodePOLType(i).name
if i < 3
else NodePOLType(3).name + " " + str(i - 3)
for i in index_changed
]
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
# Account for hardware states, software states and links
states = [
LinkStatus(obs_change[i]).name
@@ -367,9 +350,7 @@ def transform_action_node_readable(action):
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -397,9 +378,7 @@ def node_action_description(action):
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = (
f"NODE {node_id}, {node_property}, SET TO {property_action}"
)
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
elif node_property == "SERVICE":
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
else:
@@ -522,11 +501,7 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":