From e11fd2ced4331ebfaea4c1b75ea018c185cbf204 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 16:52:57 +0100 Subject: [PATCH] #917 - Fixed the RLlib integration - Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files. --- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 +- src/primaite/__init__.py | 27 +--- src/primaite/acl/access_control_list.py | 55 ++----- src/primaite/agents/agent.py | 71 ++------- src/primaite/agents/hardcoded_acl.py | 66 +++------ src/primaite/agents/hardcoded_node.py | 14 +- src/primaite/agents/rllib.py | 74 ++-------- src/primaite/agents/sb3.py | 66 +++------ src/primaite/agents/simple.py | 6 +- src/primaite/agents/utils.py | 43 ++---- src/primaite/cli.py | 29 +--- src/primaite/config/lay_down_config.py | 8 +- src/primaite/config/training_config.py | 25 +--- src/primaite/environment/observations.py | 38 ++--- src/primaite/environment/primaite_env.py | 136 ++++++------------ src/primaite/environment/reward.py | 41 ++---- src/primaite/links/link.py | 4 +- src/primaite/main.py | 8 +- src/primaite/nodes/active_node.py | 49 ++----- .../nodes/node_state_instruction_green.py | 4 +- .../nodes/node_state_instruction_red.py | 4 +- src/primaite/nodes/passive_node.py | 4 +- src/primaite/nodes/service_node.py | 24 +--- src/primaite/pol/green_pol.py | 66 ++------- src/primaite/pol/red_agent_pol.py | 63 ++------ src/primaite/primaite_session.py | 119 ++++----------- src/primaite/setup/reset_demo_notebooks.py | 12 +- src/primaite/setup/reset_example_configs.py | 8 +- src/primaite/transactions/transaction.py | 8 +- src/primaite/utils/session_output_writer.py | 4 +- .../legacy_training_config.yaml | 0 .../new_training_config.yaml | 0 tests/conftest.py | 10 +- tests/mock_and_patch/get_session_path_mock.py | 4 +- tests/test_acl.py | 4 +- tests/test_active_node.py | 12 +- tests/test_observation_space.py | 18 +-- tests/test_primaite_session.py | 10 +- tests/test_resetting_node.py | 16 +-- tests/test_service_node.py | 8 +- tests/test_single_action_space.py | 4 +- tests/test_training_config.py | 12 +- 43 files changed, 284 insertions(+), 896 deletions(-) rename tests/config/{legacy => legacy_conversion}/legacy_training_config.yaml (100%) rename tests/config/{legacy => legacy_conversion}/new_training_config.yaml (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26cd5697..6e435bee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: rev: 23.1.0 hooks: - id: black - args: [ "--line-length=79" ] + args: [ "--line-length=120" ] additional_dependencies: - jupyter - repo: http://github.com/pycqa/isort diff --git a/pyproject.toml b/pyproject.toml index b2957789..86418eaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,9 +72,9 @@ primaite = "primaite.cli:app" [tool.isort] profile = "black" -line_length = 79 +line_length = 120 force_sort_within_sections = "False" order_by_type = "False" [tool.black] -line-length = 79 +line-length = 120 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index e753b4ef..030860d8 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -19,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): - config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) - ) + config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) with open(config_path, "r") as file: primaite_config = yaml.safe_load(file) log_level_map = { @@ -34,9 +30,7 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[ - primaite_config["logging"]["log_level"] - ] + primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] return primaite_config @@ -82,14 +76,9 @@ class _LevelFormatter(Formatter): super().__init__() if "fmt" in kwargs: - raise ValueError( - "Format string must be passed to level-surrogate formatters, " - "not this one" - ) + raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one") - self.formats = sorted( - (level, Formatter(fmt, **kwargs)) for level, fmt in formats.items() - ) + self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()) def format(self, record: LogRecord) -> str: """Overrides ``Formatter.format``.""" @@ -110,13 +99,9 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][ - "WARNING" - ], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][ - "CRITICAL" - ], + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], } ) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index a147d963..3b0e9234 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,9 +10,7 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[ - str, AccessControlList - ] = {} # A dictionary of ACL Rules + self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -27,29 +25,16 @@ class AccessControlList: True if match; False otherwise. """ if ( - ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == "ANY" - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == "ANY" - ) - or ( - _rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY" - ) + (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") + or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") ): return True else: return False - def is_blocked( - self, _source_ip_address, _dest_ip_address, _protocol, _port - ): + def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): """ Checks for rules that block a protocol / port. @@ -63,15 +48,9 @@ class AccessControlList: Indicates block if all conditions are satisfied. """ for rule_key, rule_value in self.acl.items(): - if self.check_address_match( - rule_value, _source_ip_address, _dest_ip_address - ): - if ( - rule_value.get_protocol() == _protocol - or rule_value.get_protocol() == "ANY" - ) and ( - str(rule_value.get_port()) == str(_port) - or rule_value.get_port() == "ANY" + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" ): # There's a matching rule. Get the permission if rule_value.get_permission() == "DENY": @@ -93,9 +72,7 @@ class AccessControlList: _protocol: the protocol _port: the port """ - new_rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(new_rule) self.acl[hash_value] = new_rule @@ -110,9 +87,7 @@ class AccessControlList: _protocol: the protocol _port: the port """ - rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) # There will not always be something 'popable' since the agent will be trying random things try: @@ -124,9 +99,7 @@ class AccessControlList: """Removes all rules.""" self.acl.clear() - def get_dictionary_hash( - self, _permission, _source_ip, _dest_ip, _protocol, _port - ): + def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): """ Produces a hash value for a rule. @@ -140,8 +113,6 @@ class AccessControlList: Returns: Hash value based on rule parameters. """ - rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) return hash_value diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index c76583c0..90eb2b66 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -5,7 +5,7 @@ import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Dict, Final, Optional, Union +from typing import Dict, Final, Union from uuid import uuid4 import yaml @@ -51,16 +51,12 @@ class AgentSessionABC(ABC): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load( - self._training_config_path - ) + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load( - self._lay_down_config_path - ) + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) self.output_verbose_level = self._training_config.output_verbose_level self._env: Primaite @@ -132,9 +128,7 @@ class AgentSessionABC(ABC): "learning": {"total_episodes": None, "total_time_steps": None}, "evaluation": {"total_episodes": None, "total_time_steps": None}, "env": { - "training_config": self._training_config.to_dict( - json_serializable=True - ), + "training_config": self._training_config.to_dict(json_serializable=True), "lay_down_config": self._lay_down_config, }, } @@ -161,19 +155,11 @@ class AgentSessionABC(ABC): metadata_dict["end_datetime"] = datetime.now().isoformat() if not self.is_eval: - metadata_dict["learning"][ - "total_episodes" - ] = self._env.episode_count # noqa - metadata_dict["learning"][ - "total_time_steps" - ] = self._env.total_step_count # noqa + metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa else: - metadata_dict["evaluation"][ - "total_episodes" - ] = self._env.episode_count # noqa - metadata_dict["evaluation"][ - "total_time_steps" - ] = self._env.total_step_count # noqa + metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -184,12 +170,9 @@ class AgentSessionABC(ABC): @abstractmethod def _setup(self): _LOGGER.info( - "Welcome to the Primary-level AI Training Environment " - f"(PrimAITE) (version: {primaite.__version__})" - ) - _LOGGER.info( - f"The output directory for this session is: {self.session_path}" + "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})" ) + _LOGGER.info(f"The output directory for this session is: {self.session_path}") self._write_session_metadata_file() self._can_learn = True self._can_evaluate = False @@ -201,17 +184,11 @@ class AgentSessionABC(ABC): @abstractmethod def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ if self._can_learn: @@ -225,17 +202,11 @@ class AgentSessionABC(ABC): @abstractmethod def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ self._env.set_as_eval() # noqa @@ -281,9 +252,7 @@ class AgentSessionABC(ABC): else: # Session path does not exist - msg = ( - f"Failed to load PrimAITE Session, path does not exist: {path}" - ) + msg = f"Failed to load PrimAITE Session, path does not exist: {path}" _LOGGER.error(msg) raise FileNotFoundError(msg) pass @@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC): def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ _LOGGER.warning("Deterministic agents cannot learn") @@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC): def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ self._env.set_as_eval() # noqa self.is_eval = True - if not time_steps: - time_steps = self._training_config.num_steps + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes - if not episodes: - episodes = self._training_config.num_episodes obs = self._env.reset() for episode in range(episodes): # Reset env and collect initial observation diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index f70320f1..263ccbdc 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" def _calculate_action(self, obs): - if ( - self._training_config.hard_coded_agent_view - == HardCodedAgentView.BASIC - ): + if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) else: @@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): port = green_ier.get_port() # Can be blocked by an ACL or by default (no allow rule exists) - if acl.is_blocked( - source_node_address, dest_node_address, protocol, port - ): + if acl.is_blocked(source_node_address, dest_node_address, protocol, port): blocked_green_iers[green_ier_id] = green_ier return blocked_green_iers @@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules( - source_node_address, dest_node_address, protocol, port - ) + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): @@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = dest_node_id if protocol != "ANY": - protocol = services_list[ - protocol - 1 - ] # -1 as dont have to account for ANY in list of services + protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services - matching_rules = acl.get_relevant_rules( - source_node_address, dest_node_address, protocol, port - ) + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules def get_allow_acl_rules( @@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_decision = "DELETE" action_permission = "ALLOW" action_source_ip = rule.get_source_ip() - action_source_id = int( - get_node_of_ip(action_source_ip, self._env.nodes) - ) + action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes)) action_destination_ip = rule.get_dest_ip() - action_destination_id = int( - get_node_of_ip( - action_destination_ip, self._env.nodes - ) - ) + action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes)) action_protocol_name = rule.get_protocol() action_protocol = ( - self._env.services_list.index(action_protocol_name) - + 1 + self._env.services_list.index(action_protocol_name) + 1 ) # convert name e.g. 'TCP' to index action_port_name = rule.get_port() action_port = ( @@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if not found_action: # Which Green IERS are blocked - blocked_green_iers = self.get_blocked_green_iers( - self._env.green_iers, self._env.acl, self._env.nodes - ) + blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes) for ier_key, ier in blocked_green_iers.items(): # Which ALLOW rules are allowing this IER (none) - allowing_rules = self.get_allow_acl_rules_for_ier( - ier, self._env.acl, self._env.nodes - ) + allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes) # If there are no blocking rules, it may be being blocked by default # If there is already an allow rule node_id_to_check = int(ier.get_source_node_id()) service_name_to_check = ier.get_protocol() - service_id_to_check = self._env.services_list.index( - service_name_to_check - ) + service_id_to_check = self._env.services_list.index(service_name_to_check) # Service state of the the source node in the ier service_state = s[service_id_to_check][node_id_to_check - 1] @@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if len(r_obs) == 4: # only 1 service s = [*s] - number_of_nodes = len( - [i for i in o if i != "NONE"] - ) # number of nodes (not links) + number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links) for service_num, service_states in enumerate(s): - comprimised_states = [ - n for n, i in enumerate(service_states) if i == "COMPROMISED" - ] + comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"] if len(comprimised_states) == 0: # No states are COMPROMISED, try the next service continue - compromised_node = ( - np.random.choice(comprimised_states) + 1 - ) # +1 as 0 would be any + compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any action_decision = "DELETE" action_permission = "ALLOW" action_source_ip = compromised_node # Randomly select a destination ID to block - action_destination_ip = np.random.choice( - list(range(1, number_of_nodes + 1)) + ["ANY"] - ) + action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"]) action_destination_ip = ( - int(action_destination_ip) - if action_destination_ip != "ANY" - else action_destination_ip + int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip ) action_protocol = service_num + 1 # +1 as 0 is any # Randomly select a port to block diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index e258edb0..310fc178 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,9 +1,5 @@ from primaite.agents.agent import HardCodedAgentSessionABC -from primaite.agents.utils import ( - get_new_action, - transform_action_node_enum, - transform_change_obs_readable, -) +from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable class HardCodedNodeAgent(HardCodedAgentSessionABC): @@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if os_state == "OFF": action_node_id = x + 1 action_node_property = "OPERATING" - property_action = ( - "ON" # Why reset it when we can just turn it on - ) - action_service_index = ( - 0 # does nothing isn't relevant for operating state - ) + property_action = "ON" # Why reset it when we can just turn it on + action_service_index = 0 # does nothing isn't relevant for operating state action = [ action_node_id, action_node_property, diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 35ae1b53..2b6a5a83 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -3,9 +3,8 @@ from __future__ import annotations import json from datetime import datetime from pathlib import Path -from typing import Optional, Union +from typing import Union -import tensorflow as tf from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig @@ -14,11 +13,7 @@ from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import ( - AgentFramework, - AgentIdentifier, - DeepLearningFramework, -) +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) @@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC): def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = ( - f"Expected RLLIB agent_framework, " - f"got {self._training_config.agent_framework}" - ) + msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = A2CConfig else: - msg = ( - "Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}" - ) + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result[ - "episodes_total" - ] - metadata_dict["total_time_steps"] = self._current_result[ - "timesteps_total" - ] + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC): ), ) - self._agent_config.training( - train_batch_size=self._training_config.num_steps - ) + self._agent_config.training(train_batch_size=self._training_config.num_steps) self._agent_config.framework(framework="tf") self._agent_config.rollouts( @@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC): num_envs_per_worker=1, horizon=self._training_config.num_steps, ) - self._agent: Algorithm = self._agent_config.build( - logger_creator=_custom_log_creator(self.session_path) - ) + self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or ( - episode_count == self._training_config.num_episodes - ): + if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): self._agent.save(str(self.checkpoints_path)) def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ - # Temporarily override train_batch_size and horizon - if time_steps: - self._agent_config.train_batch_size = time_steps - self._agent_config.horizon = time_steps + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes - if not episodes: - episodes = self._training_config.num_episodes - _LOGGER.info( - f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps..." - ) + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - if ( - self._training_config.deep_learning_framework - != DeepLearningFramework.TORCH - ): - policy = self._agent.get_policy() - tf.compat.v1.summary.FileWriter( - self.session_path / "ray_results", policy.get_session().graph - ) - super().learn() self._agent.stop() + super().learn() def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 8d5dd633..3161c93a 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, Union +from typing import Union import numpy as np from stable_baselines3 import A2C, PPO @@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC): def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.SB3: - msg = ( - f"Expected SB3 agent_framework, " - f"got {self._training_config.agent_framework}" - ) + msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ( - "Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier}" - ) + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}" _LOGGER.error(msg) raise ValueError(msg) @@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC): self._env, verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, - tensorboard_log=self._tensorboard_log_path, + tensorboard_log=str(self._tensorboard_log_path), ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or ( - episode_count == self._training_config.num_episodes - ): - checkpoint_path = ( - self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - ) + if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" self._agent.save(checkpoint_path) _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") @@ -85,58 +75,37 @@ class SB3Agent(AgentSessionABC): def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ - if not time_steps: - time_steps = self._training_config.num_steps - - if not episodes: - episodes = self._training_config.num_episodes + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes self.is_eval = False - _LOGGER.info( - f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps..." - ) + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() - - self.close() + self._env.reset() + self._env.close() super().learn() def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, deterministic: bool = True, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. """ - if not time_steps: - time_steps = self._training_config.num_steps - - if not episodes: - episodes = self._training_config.num_episodes + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes self._env.set_as_eval() self.is_eval = True if deterministic: @@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC): else: deterministic_str = "non-deterministic" _LOGGER.info( - f"Beginning {deterministic_str} evaluation for " - f"{episodes} episodes @ {time_steps} time steps..." + f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." ) for episode in range(episodes): obs = self._env.reset() for step in range(time_steps): - action, _states = self._agent.predict( - obs, deterministic=deterministic - ) + action, _states = self._agent.predict(obs, deterministic=deterministic) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + self._env.reset() + self._env.close() super().evaluate() @classmethod diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index cf333b1e..5a6c9da5 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,9 +1,5 @@ from primaite.agents.agent import HardCodedAgentSessionABC -from primaite.agents.utils import ( - get_new_action, - transform_action_acl_enum, - transform_action_node_enum, -) +from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum class RandomAgent(HardCodedAgentSessionABC): diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index c3e67fdf..8c59faf7 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -24,9 +24,7 @@ def transform_action_node_readable(action): if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif ( - action_node_property == "OS" or action_node_property == "SERVICE" - ) and action[2] <= 1: + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -117,11 +115,7 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if ( - action_source_id == action_destination_id - and action_source_id != "ANY" - and action_destination_id != "ANY" - ): + if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": # ACL rule towards itself return False if action_permission == "DENY": @@ -173,9 +167,7 @@ def transform_change_obs_readable(obs): for service in range(3, obs.shape[1]): # Links bit/s don't have a service state - service_states = [ - SoftwareState(i).name if i <= 4 else i for i in obs[:, service] - ] + service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] new_obs.append(service_states) return new_obs @@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change( - obs1, obs2, num_nodes=10, num_links=10, num_services=1 -): +def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): """ Return string describing change between two observations. @@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link): TODO: Typehint params and return. """ # Indexes where a change has occured, not including 0th index - index_changed = [ - i for i in range(1, len(obs_change)) if obs_change[i] != -1 - ] + index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] # Node pol types, Indexes >= 3 are service nodes - NodePOLTypes = [ - NodePOLType(i).name - if i < 3 - else NodePOLType(3).name + " " + str(i - 3) - for i in index_changed - ] + NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed] # Account for hardware states, software sattes and links states = [ LinkStatus(obs_change[i]).name @@ -367,9 +350,7 @@ def transform_action_node_readable(action): if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif ( - action_node_property == "OS" or action_node_property == "SERVICE" - ) and action[2] <= 1: + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -397,9 +378,7 @@ def node_action_description(action): if property_action == "NONE": return "" if node_property == "OPERATING" or node_property == "OS": - description = ( - f"NODE {node_id}, {node_property}, SET TO {property_action}" - ) + description = f"NODE {node_id}, {node_property}, SET TO {property_action}" elif node_property == "SERVICE": description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" else: @@ -522,11 +501,7 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if ( - action_source_id == action_destination_id - and action_source_id != "ANY" - and action_destination_id != "ANY" - ): + if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": # ACL rule towards itself return False if action_permission == "DENY": diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 0431174f..40e8cf0d 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -56,9 +56,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum( - "LogLevel", {k: k for k in logging._levelToName.values()} -) # noqa +_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa @app.command() @@ -124,21 +122,12 @@ def setup(overwrite_existing: bool = True): app_dirs = PlatformDirs(appname="primaite") app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - pkg_config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) - ) + pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) shutil.copy2(pkg_config_path, user_config_path) from primaite import getLogger - from primaite.setup import ( - old_installation_clean_up, - reset_demo_notebooks, - reset_example_configs, - setup_app_dirs, - ) + from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs _LOGGER = getLogger(__name__) @@ -188,9 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): @app.command() -def plotly_template( - template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None -): +def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None): """ View or set the plotly template for Session plots. @@ -208,14 +195,10 @@ def plotly_template( primaite_config = yaml.safe_load(file) if template: - primaite_config["session"]["outputs"]["plots"][ - "template" - ] = template.value + primaite_config["session"]["outputs"]["plots"]["template"] = template.value with open(user_config_path, "w") as file: yaml.dump(primaite_config, file) print(f"PrimAITE plotly template: {template.value}") else: - template = primaite_config["session"]["outputs"]["plots"][ - "template" - ] + template = primaite_config["session"]["outputs"]["plots"]["template"] print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index ae067228..08f77b2f 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -8,14 +8,10 @@ from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[Path] = ( - USERS_CONFIG_DIR / "example_config" / "lay_down" -) +_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" -def convert_legacy_lay_down_config_dict( - legacy_config_dict: Dict[str, Any] -) -> Dict[str, Any]: +def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: """ Convert a legacy lay down config dict to the new format. diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 84dd3cc8..3e0f26ca 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -20,9 +20,7 @@ from primaite.common.enums import ( _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[Path] = ( - USERS_CONFIG_DIR / "example_config" / "training" -) +_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" def main_training_config_path() -> Path: @@ -68,9 +66,7 @@ class TrainingConfig: checkpoint_every_n_episodes: int = 5 "The agent will save a checkpoint every n episodes" - observation_space: dict = field( - default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} - ) + observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}) "The observation space config dict" time_delay: int = 10 @@ -180,9 +176,7 @@ class TrainingConfig: "The time taken to scan the file system" @classmethod - def from_dict( - cls, config_dict: Dict[str, Union[str, int, bool]] - ) -> TrainingConfig: + def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: """ Create an instance of TrainingConfig from a dict. @@ -238,9 +232,7 @@ class TrainingConfig: return tc -def load( - file_path: Union[str, Path], legacy_file: bool = False -) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -271,10 +263,7 @@ def load( try: return TrainingConfig.from_dict(config) except TypeError as e: - msg = ( - f"Error when creating an instance of {TrainingConfig} " - f"from the training config file {file_path}" - ) + msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}" _LOGGER.critical(msg, exc_info=True) raise e msg = f"Cannot load the training config as it does not exist: {file_path}" @@ -314,9 +303,7 @@ def convert_legacy_training_config_dict( "output_verbose_level": output_verbose_level.name, } session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} - legacy_config_dict["sessionType"] = session_type_map[ - legacy_config_dict["sessionType"] - ] + legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]] for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 6893125e..d0d5d46e 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeros( - observation_shape, dtype=self._DATA_TYPE - ) + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) def update(self): """Update the observation based on current environment state. @@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation[item_index][0] = int(node.node_id) self.current_observation[item_index][1] = node.hardware_state.value if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.current_observation[item_index][ - 2 - ] = node.software_state.value - self.current_observation[item_index][ - 3 - ] = node.file_system_state_observed.value + self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][3] = node.file_system_state_observed.value else: self.current_observation[item_index][2] = 0 self.current_observation[item_index][3] = 0 @@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent): if isinstance(node, ServiceNode): for service in self.env.services_list: if node.has_service(service): - self.current_observation[item_index][ - service_index - ] = node.get_service_state(service).value + self.current_observation[item_index][service_index] = node.get_service_state(service).value else: self.current_observation[item_index][service_index] = 0 service_index += 1 @@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent): protocol_list = link.get_protocol_list() protocol_index = 0 for protocol in protocol_list: - self.current_observation[item_index][ - protocol_index + 4 - ] = protocol.get_load() + self.current_observation[item_index][protocol_index + 4] = protocol.get_load() protocol_index += 1 item_index += 1 @@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent): if isinstance(node, ServiceNode): for i, service in enumerate(self.env.services_list): if node.has_service(service): - service_states[i] = node.get_service_state( - service - ).value + service_states[i] = node.get_service_state(service).value obs.extend( [ hardware_state, @@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent): self._entries_per_link = self.env.num_services # 1. Define the shape of your observation space component - shape = ( - [self._quantisation_levels] - * self.env.num_links - * self._entries_per_link - ) + shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent): if self._combine_service_traffic: loads = [link.get_current_load()] else: - loads = [ - protocol.get_load() for protocol in link.protocol_list - ] + loads = [protocol.get_load() for protocol in link.protocol_list] for load in loads: if load <= 0: @@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent): elif load >= bandwidth: traffic_level = self._quantisation_levels - 1 else: - traffic_level = (load / bandwidth) // ( - 1 / (self._quantisation_levels - 2) - ) + 1 + traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1 obs.append(int(traffic_level)) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index ea8f82d4..df51e21e 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -12,13 +12,11 @@ from matplotlib import pyplot as plt from primaite import getLogger from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import ( - is_valid_acl_action_extra, - is_valid_node_action, -) +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, + AgentFramework, FileSystemState, HardwareState, NodePOLInitiator, @@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import ( - NodeStateInstructionGreen, -) +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import ( - apply_red_agent_iers, - apply_red_agent_node_pol, -) +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter @@ -85,9 +78,7 @@ class Primaite(Env): self._training_config_path = training_config_path self._lay_down_config_path = lay_down_config_path - self.training_config: TrainingConfig = training_config.load( - training_config_path - ) + self.training_config: TrainingConfig = training_config.load(training_config_path) _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode @@ -238,25 +229,22 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.error( - f"Invalid action type selected: {self.training_config.action_type}" - ) + _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}") - self.episode_av_reward_writer = SessionOutputWriter( - self, transaction_writer=False, learning_session=True - ) - self.transaction_writer = SessionOutputWriter( - self, transaction_writer=True, learning_session=True - ) + self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True) + + @property + def actual_episode_count(self) -> int: + """Shifts the episode_count by -1 for RLlib.""" + if self.training_config.agent_framework is AgentFramework.RLLIB: + return self.episode_count - 1 + return self.episode_count def set_as_eval(self): """Set the writers to write to eval directories.""" - self.episode_av_reward_writer = SessionOutputWriter( - self, transaction_writer=False, learning_session=False - ) - self.transaction_writer = SessionOutputWriter( - self, transaction_writer=True, learning_session=False - ) + self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False) self.episode_count = 0 self.step_count = 0 self.total_step_count = 0 @@ -268,8 +256,8 @@ class Primaite(Env): Returns: Environment observation space (reset) """ - if self.episode_count > 0: - csv_data = self.episode_count, self.average_reward + if self.actual_episode_count > 0: + csv_data = self.actual_episode_count, self.average_reward self.episode_av_reward_writer.write(csv_data) self.episode_count += 1 @@ -291,6 +279,7 @@ class Primaite(Env): # Update observations space and return self.update_environent_obs() + return self.env_obs def step(self, action): @@ -319,9 +308,7 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction( - self.agent_identifier, self.episode_count, self.step_count - ) + transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction transaction.obs_space_pre = copy.deepcopy(self.env_obs) # Load the action space into the transaction @@ -350,9 +337,7 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol( - self.nodes_reference, self.node_pol, self.step_count - ) # Node PoL + apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL apply_iers( self.network_reference, self.nodes_reference, @@ -371,9 +356,7 @@ class Primaite(Env): self.acl, self.step_count, ) - apply_red_agent_node_pol( - self.nodes, self.red_iers, self.red_node_pol, self.step_count - ) + apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count) # Take snapshots of nodes and links self.nodes_post_red = copy.deepcopy(self.nodes) self.links_post_red = copy.deepcopy(self.links) @@ -389,11 +372,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - _LOGGER.debug( - f"Episode: {self.episode_count}, " - f"Step {self.step_count}, " - f"Reward: {reward}" - ) + _LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -401,10 +380,7 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info( - f"Episode: {self.episode_count}, " - f"Average Reward: {self.average_reward}" - ) + _LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}") # Load the reward into the transaction transaction.reward = reward @@ -417,11 +393,21 @@ class Primaite(Env): transaction.obs_space_post = copy.deepcopy(self.env_obs) # Write transaction to file - self.transaction_writer.write(transaction) + if self.actual_episode_count > 0: + self.transaction_writer.write(transaction) # Return return self.env_obs, reward, done, self.step_info + def close(self): + """Override parent close and close writers.""" + # Close files if last episode/step + # if self.can_finish: + super().close() + + self.transaction_writer.close() + self.episode_av_reward_writer.close() + def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -431,12 +417,7 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - _LOGGER.debug( - " Protocol: " - + protocol.get_name().name - + ", Load: " - + str(protocol.get_load()) - ) + _LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): """ @@ -450,13 +431,9 @@ class Primaite(Env): self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 6 - ): # ACL actions in multidiscrete form have len 6 + elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 4 - ): # Node actions in multdiscrete (array) from have len 4 + elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: _LOGGER.error("Invalid action type found") @@ -541,10 +518,7 @@ class Primaite(Env): elif property_action == 2: # Repair # You cannot repair a destroyed file system - it needs restoring - if ( - node.file_system_state_actual - != FileSystemState.DESTROYED - ): + if node.file_system_state_actual != FileSystemState.DESTROYED: node.set_file_system_state(FileSystemState.REPAIRING) elif property_action == 3: # Restore @@ -587,9 +561,7 @@ class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance( - node, ActiveNode - ): + if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): acl_rule_source = node.ip_address else: return @@ -598,9 +570,7 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance( - node, ActiveNode - ): + if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): acl_rule_destination = node.ip_address else: return @@ -685,9 +655,7 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config( - self, self.obs_config - ) + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation @@ -794,9 +762,7 @@ class Primaite(Env): service_protocol = service["name"] service_port = service["port"] service_state = SoftwareState[service["state"]] - node.add_service( - Service(service_protocol, service_port, service_state) - ) + node.add_service(Service(service_protocol, service_port, service_state)) else: # Bad formatting pass @@ -849,9 +815,7 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge( - source_node_ref, dest_node_ref, id=link_name - ) + self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) # Add link to link dictionary (reference) self.links_reference[link_name] = Link( @@ -1126,9 +1090,7 @@ class Primaite(Env): # All nodes have these parameters node_id = item["node_id"] node_class = item["node_class"] - node_hardware_state: HardwareState = HardwareState[ - item["hardware_state"] - ] + node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1249,11 +1211,7 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = { - k + len(acl_action_dict) - 1: v - for k, v in node_action_dict.items() - if k != 0 - } + new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 4dd0550e..19094a18 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -41,29 +41,19 @@ def calculate_reward_function( reference_node = reference_nodes[node_key] # Hardware State - reward_value += score_node_operating_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values) # Software State - if isinstance(final_node, ActiveNode) or isinstance( - final_node, ServiceNode - ): - reward_value += score_node_os_state( - final_node, initial_node, reference_node, config_values - ) + if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): + reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values) # Service State if isinstance(final_node, ServiceNode): - reward_value += score_node_service_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values) # File System State if isinstance(final_node, ActiveNode): - reward_value += score_node_file_system( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values) # Go through each red IER - penalise if it is running for ier_key, ier_value in red_iers.items(): @@ -82,10 +72,7 @@ def calculate_reward_function( if step_count >= start_step and step_count <= stop_step: reference_blocked = not reference_ier.get_is_running() live_blocked = not ier_value.get_is_running() - ier_reward = ( - config_values.green_ier_blocked - * ier_value.get_mission_criticality() - ) + ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality() if live_blocked and not reference_blocked: reward_value += ier_reward @@ -107,9 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state( - final_node, initial_node, reference_node, config_values -): +def score_node_operating_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the hardware state of a node. @@ -158,9 +143,7 @@ def score_node_operating_state( return score -def score_node_os_state( - final_node, initial_node, reference_node, config_values -): +def score_node_os_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the Software State of a node. @@ -211,9 +194,7 @@ def score_node_os_state( return score -def score_node_service_state( - final_node, initial_node, reference_node, config_values -): +def score_node_service_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the service state(s) of a node. @@ -285,9 +266,7 @@ def score_node_service_state( return score -def score_node_file_system( - final_node, initial_node, reference_node, config_values -): +def score_node_file_system(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 054f4c34..90235e9f 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -8,9 +8,7 @@ from primaite.common.protocol import Protocol class Link(object): """Link class.""" - def __init__( - self, _id, _bandwidth, _source_node_name, _dest_node_name, _services - ): + def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): """ Init. diff --git a/src/primaite/main.py b/src/primaite/main.py index 556c5ec3..7b1d7ab3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -32,11 +32,7 @@ if __name__ == "__main__": parser.add_argument("--ldc") args = parser.parse_args() if not args.tc: - _LOGGER.error( - "Please provide a training config file using the --tc " "argument" - ) + _LOGGER.error("Please provide a training config file using the --tc " "argument") if not args.ldc: - _LOGGER.error( - "Please provide a lay down config file using the --ldc " "argument" - ) + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index b1c3f57c..07a0ea0a 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -3,13 +3,7 @@ import logging from typing import Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.config.training_config import TrainingConfig from primaite.nodes.node import Node @@ -44,9 +38,7 @@ class ActiveNode(Node): :param file_system_state: The node file system state :param config_values: The config values """ - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) self.ip_address: str = ip_address # Related to Software self._software_state: SoftwareState = software_state @@ -87,9 +79,7 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def set_software_state_if_not_compromised( - self, software_state: SoftwareState - ): + def set_software_state_if_not_compromised(self, software_state: SoftwareState): """ Sets Software State if the node is not compromised. @@ -100,9 +90,7 @@ class ActiveNode(Node): if self._software_state != SoftwareState.COMPROMISED: self._software_state = software_state if software_state == SoftwareState.PATCHING: - self.patching_count = ( - self.config_values.os_patching_duration - ) + self.patching_count = self.config_values.os_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so OS State cannot be changed." @@ -129,14 +117,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD @@ -149,9 +133,7 @@ class ActiveNode(Node): f"Node.file_system_state.actual:{self.file_system_state_actual}" ) - def set_file_system_state_if_not_compromised( - self, file_system_state: FileSystemState - ): + def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): """ Sets the file system state (actual and observed) if not in a compromised state. @@ -168,14 +150,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD @@ -191,9 +169,7 @@ class ActiveNode(Node): def start_file_system_scan(self): """Starts a file system scan.""" self.file_system_scanning = True - self.file_system_scanning_count = ( - self.config_values.file_system_scanning_limit - ) + self.file_system_scanning_count = self.config_values.file_system_scanning_limit def update_file_system_state(self): """Updates file system status based on scanning/restore/repair cycle.""" @@ -212,10 +188,7 @@ class ActiveNode(Node): self.file_system_state_observed = FileSystemState.GOOD # Scanning updates - if ( - self.file_system_scanning == True - and self.file_system_scanning_count < 0 - ): + if self.file_system_scanning == True and self.file_system_scanning_count < 0: self.file_system_state_observed = self.file_system_state_actual self.file_system_scanning = False self.file_system_scanning_count = 0 diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 04681807..2b1d94be 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object): self.end_step = _end_step self.node_id = _node_id self.node_pol_type = _node_pol_type - self.service_name = ( - _service_name # Not used when not a service instruction - ) + self.service_name = _service_name # Not used when not a service instruction self.state = _state def get_start_step(self): diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index ba35067c..7f62fe24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -42,9 +42,7 @@ class NodeStateInstructionRed(object): self.target_node_id = _target_node_id self.initiator = _pol_initiator self.pol_type: NodePOLType = _pol_type - self.service_name = ( - pol_protocol # Not used when not a service instruction - ) + self.service_name = pol_protocol # Not used when not a service instruction self.state = _pol_state self.source_node_id = _pol_source_node_id self.source_node_service = _pol_source_node_service diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 6515097a..9aa5c7d7 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -28,9 +28,7 @@ class PassiveNode(Node): :param config_values: Config values. """ # Pass through to Super for now - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) @property def ip_address(self) -> str: diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 6dcff73e..5d69df92 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -3,13 +3,7 @@ import logging from typing import Dict, Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -110,9 +104,7 @@ class ServiceNode(ActiveNode): return False return False - def set_service_state( - self, protocol_name: str, software_state: SoftwareState - ): + def set_service_state(self, protocol_name: str, software_state: SoftwareState): """ Sets the software_state of a service (protocol) on the node. @@ -130,9 +122,7 @@ class ServiceNode(ActiveNode): ) or software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " @@ -143,9 +133,7 @@ class ServiceNode(ActiveNode): f"Node.services[].software_state:{software_state}" ) - def set_service_state_if_not_compromised( - self, protocol_name: str, software_state: SoftwareState - ): + def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): """ Sets the software_state of a service (protocol) on the node. @@ -161,9 +149,7 @@ class ServiceNode(ActiveNode): if service_value.software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index aeae7add..e9dfef8c 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - NodePOLType, - NodeType, - SoftwareState, -) +from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_green import ( - NodeStateInstructionGreen, -) +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER @@ -93,9 +86,7 @@ def apply_iers( and source_node.software_state != SoftwareState.PATCHING ): if source_node.has_service(protocol): - if source_node.service_running( - protocol - ) and not source_node.service_is_overwhelmed(protocol): + if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol): source_valid = True else: source_valid = False @@ -110,10 +101,7 @@ def apply_iers( # 2. Check the dest node situation if dest_node.node_type == NodeType.SWITCH: # It's a switch - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: dest_valid = True else: # IER no longer valid @@ -123,14 +111,9 @@ def apply_iers( pass else: # It's not a switch or an actuator (so active node) - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: if dest_node.has_service(protocol): - if dest_node.service_running( - protocol - ) and not dest_node.service_is_overwhelmed(protocol): + if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol): dest_valid = True else: dest_valid = False @@ -143,9 +126,7 @@ def apply_iers( dest_valid = False # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -176,10 +157,7 @@ def apply_iers( # We might have a switch in the path, so check all nodes are operational for node in path_node_list: - if ( - node.hardware_state != HardwareState.ON - or node.software_state == SoftwareState.PATCHING - ): + if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING: path_valid = False if path_valid: @@ -191,15 +169,11 @@ def apply_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if ( - link.get_current_load() + load - ) > link.get_bandwidth(): + if (link.get_current_load() + load) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -226,9 +200,7 @@ def apply_iers( else: # One of the nodes is not operational if _VERBOSE: - print( - "Path not valid - one or more nodes not operational" - ) + print("Path not valid - one or more nodes not operational") pass else: @@ -243,9 +215,7 @@ def apply_iers( def apply_node_pol( nodes: Dict[str, NodeUnion], - node_pol: Dict[ - any, Union[NodeStateInstructionGreen, NodeStateInstructionRed] - ], + node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], step: int, ): """ @@ -277,22 +247,16 @@ def apply_node_pol( elif node_pol_type == NodePOLType.OS: # Change OS state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this - if isinstance(node, ActiveNode) or isinstance( - node, ServiceNode - ): + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): node.set_software_state_if_not_compromised(state) elif node_pol_type == NodePOLType.SERVICE: # Change a service state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this if isinstance(node, ServiceNode): - node.set_service_state_if_not_compromised( - service_name, state - ) + node.set_service_state_if_not_compromised(service_name, state) else: # Change the file system status - if isinstance(node, ActiveNode) or isinstance( - node, ServiceNode - ): + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): node.set_file_system_state_if_not_compromised(state) else: # PoL is not valid in this time step diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 96fe787c..bff19bf8 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - NodePOLInitiator, - NodePOLType, - NodeType, - SoftwareState, -) +from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed @@ -83,10 +77,7 @@ def apply_red_agent_iers( if source_node.hardware_state == HardwareState.ON: if source_node.has_service(protocol): # Red agents IERs can only be valid if the source service is in a compromised state - if ( - source_node.get_service_state(protocol) - == SoftwareState.COMPROMISED - ): + if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED: source_valid = True else: source_valid = False @@ -124,9 +115,7 @@ def apply_red_agent_iers( dest_valid = False # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -170,15 +159,11 @@ def apply_red_agent_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if ( - link.get_current_load() + load - ) > link.get_bandwidth(): + if (link.get_current_load() + load) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -203,23 +188,16 @@ def apply_red_agent_iers( # This IER is now valid, so set it to running ier_value.set_is_running(True) if _VERBOSE: - print( - "Red IER was allowed to run in step " - + str(step) - ) + print("Red IER was allowed to run in step " + str(step)) else: # One of the nodes is not operational if _VERBOSE: - print( - "Path not valid - one or more nodes not operational" - ) + print("Path not valid - one or more nodes not operational") pass else: if _VERBOSE: - print( - "Red IER was NOT allowed to run in step " + str(step) - ) + print("Red IER was NOT allowed to run in step " + str(step)) print("Source, Dest or ACL were not valid") pass # ------------------------------------ @@ -258,9 +236,7 @@ def apply_red_agent_node_pol( state = node_instruction.get_state() source_node_id = node_instruction.get_source_node_id() source_node_service_name = node_instruction.get_source_node_service() - source_node_service_state_value = ( - node_instruction.get_source_node_service_state() - ) + source_node_service_state_value = node_instruction.get_source_node_service_state() passed_checks = False @@ -274,9 +250,7 @@ def apply_red_agent_node_pol( passed_checks = True elif initiator == NodePOLInitiator.IER: # Need to check there is a red IER incoming - passed_checks = is_red_ier_incoming( - target_node, iers, pol_type - ) + passed_checks = is_red_ier_incoming(target_node, iers, pol_type) elif initiator == NodePOLInitiator.SERVICE: # Need to check the condition of a service on another node source_node = nodes[source_node_id] @@ -304,9 +278,7 @@ def apply_red_agent_node_pol( target_node.hardware_state = state elif pol_type == NodePOLType.OS: # Change OS state - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.software_state = state elif pol_type == NodePOLType.SERVICE: # Change a service state @@ -314,15 +286,11 @@ def apply_red_agent_node_pol( target_node.set_service_state(service_name, state) else: # Change the file system status - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.set_file_system_state(state) else: if _VERBOSE: - print( - "Node Red Agent PoL not allowed - did not pass checks" - ) + print("Node Red Agent PoL not allowed - did not pass checks") else: # PoL is not valid in this time step pass @@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type): node_id = node.node_id for ier_key, ier_value in iers.items(): - if ( - ier_value.get_is_running() - and ier_value.get_dest_node_id() == node_id - ): + if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: if ( node_pol_type == NodePOLType.OPERATING or node_pol_type == NodePOLType.OS diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 4d8d3022..4ee6c507 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, Final, Optional, Union +from typing import Dict, Final, Union from primaite import getLogger from primaite.agents.agent import AgentSessionABC @@ -9,18 +9,8 @@ from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.agents.simple import ( - DoNothingACLAgent, - DoNothingNodeAgent, - DummyAgent, - RandomAgent, -) -from primaite.common.enums import ( - ActionType, - AgentFramework, - AgentIdentifier, - SessionType, -) +from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent +from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig @@ -49,16 +39,12 @@ class PrimaiteSession: if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load( - self._training_config_path - ) + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load( - self._lay_down_config_path - ) + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) self._agent_session: AgentSessionABC = None # noqa self.session_path: Path = None # noqa @@ -69,28 +55,16 @@ class PrimaiteSession: def setup(self): """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}" - ) - if ( - self._training_config.agent_identifier - == AgentIdentifier.HARDCODED - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.HARDCODED}" - ) + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}") + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = HardCodedACLAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -100,24 +74,14 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif ( - self._training_config.agent_identifier - == AgentIdentifier.DO_NOTHING - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.DO_NOTHINGD}" - ) + elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}") if self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -127,49 +91,26 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif ( - self._training_config.agent_identifier - == AgentIdentifier.RANDOM - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.RANDOM}" - ) - self._agent_session = RandomAgent( - self._training_config_path, self._lay_down_config_path - ) - elif ( - self._training_config.agent_identifier == AgentIdentifier.DUMMY - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.DUMMY}" - ) - self._agent_session = DummyAgent( - self._training_config_path, self._lay_down_config_path - ) + elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") + self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") + self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) else: # Invalid AgentFramework AgentIdentifier combo raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}" - ) + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") # Stable Baselines3 Agent - self._agent_session = SB3Agent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) elif self._training_config.agent_framework == AgentFramework.RLLIB: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}" - ) + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") # Ray RLlib Agent - self._agent_session = RLlibAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path) else: # Invalid AgentFramework @@ -182,35 +123,27 @@ class PrimaiteSession: def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of time steps per episode. - :param episodes: The number of episodes. :param kwargs: Any agent-framework specific key word args. """ if not self._training_config.session_type == SessionType.EVAL: - self._agent_session.learn(time_steps, episodes, **kwargs) + self._agent_session.learn(**kwargs) def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of time steps per episode. - :param episodes: The number of episodes. :param kwargs: Any agent-framework specific key word args. """ if not self._training_config.session_type == SessionType.TRAIN: - self._agent_session.evaluate(time_steps, episodes, **kwargs) + self._agent_session.evaluate(**kwargs) def close(self): """Closes the agent.""" diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 59eaf8cc..7fa96783 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -18,23 +18,17 @@ def run(overwrite_existing: bool = True): :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off. """ - notebooks_package_data_root = pkg_resources.resource_filename( - "primaite", "notebooks/_package_data" - ) + notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data") for subdir, dirs, files in os.walk(notebooks_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath( - fp, notebooks_package_data_root - ).split(os.sep) + path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) target_fp = NOTEBOOKS_DIR / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() if overwrite_existing and not copy_file: - copy_file = (not filecmp.cmp(fp, target_fp)) and ( - ".ipynb_checkpoints" not in str(target_fp) - ) + copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp)) if copy_file: shutil.copy2(fp, target_fp) diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index f2b4a18f..5d62298c 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -17,16 +17,12 @@ def run(overwrite_existing=True): :param overwrite_existing: A bool to toggle replacing existing edited config on or off. """ - configs_package_data_root = pkg_resources.resource_filename( - "primaite", "config/_package_data" - ) + configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data") for subdir, dirs, files in os.walk(configs_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, configs_package_data_root).split( - os.sep - ) + path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 1a71f0ff..eeafe05e 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -76,12 +76,8 @@ class Transaction(object): row = ( row + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array( - self.obs_space_pre, obs_assets, obs_features - ) - + _turn_obs_space_to_array( - self.obs_space_post, obs_assets, obs_features - ) + + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features) + + _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features) ) return header, row diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 86c5ca28..a05b0453 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -51,9 +51,7 @@ class SessionOutputWriter: self._first_write: bool = True def _init_csv_writer(self): - self._csv_file = open( - self._csv_file_path, "w", encoding="UTF8", newline="" - ) + self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="") self._csv_writer = csv.writer(self._csv_file) diff --git a/tests/config/legacy/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml similarity index 100% rename from tests/config/legacy/legacy_training_config.yaml rename to tests/config/legacy_conversion/legacy_training_config.yaml diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml similarity index 100% rename from tests/config/legacy/new_training_config.yaml rename to tests/config/legacy_conversion/new_training_config.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 41dc5e77..af76b314 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession): return self def __exit__(self, type, value, tb): - del self._agent_session._env.episode_av_reward_writer - del self._agent_session._env.transaction_writer shutil.rmtree(self.session_path) shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") @@ -112,9 +110,7 @@ def temp_primaite_session(request): """ training_config_path = request.param[0] lay_down_config_path = request.param[1] - with patch( - "primaite.agents.agent.get_session_path", get_temp_session_path - ) as mck: + with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck: mck.session_timestamp = datetime.now() return TempPrimaiteSession(training_config_path, lay_down_config_path) @@ -130,9 +126,7 @@ def temp_session_path() -> Path: session_timestamp = datetime.now() date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = ( - Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path - ) + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index cfcfb8f0..feff52f6 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -16,9 +16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path: """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = ( - Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path - ) + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created temp session directory: {session_path}") return session_path diff --git a/tests/test_acl.py b/tests/test_acl.py index 260ccffc..30f12697 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -95,8 +95,6 @@ def test_rule_hash(): rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_local = hash(rule) - hash_value_remote = acl.get_dictionary_hash( - "DENY", "192.168.1.1", "192.168.1.2", "TCP", "80" - ) + hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") assert hash_value_local == hash_value_remote diff --git a/tests/test_active_node.py b/tests/test_active_node.py index b6833182..addc595c 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_software_state_if_not_compromised( - SoftwareState.OVERWHELMED - ) + active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED) assert active_node.software_state == expected_state @@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state): (HardwareState.ON, FileSystemState.CORRUPT), ], ) -def test_file_system_change_if_not_compromised( - operating_state, expected_state -): +def test_file_system_change_if_not_compromised(operating_state, expected_state): """ Test that a node cannot change its file system state. @@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised( 1, ) - active_node.set_file_system_state_if_not_compromised( - FileSystemState.CORRUPT - ) + active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT) assert active_node.file_system_state_actual == expected_state diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 21e4857f..d1082049 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -2,11 +2,7 @@ import numpy as np import pytest -from primaite.environment.observations import ( - NodeLinkTable, - NodeStatuses, - ObservationsHandler, -) +from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler from tests import TEST_CONFIG_ROOT @@ -127,9 +123,7 @@ class TestNodeLinkTable: with temp_primaite_session as session: env = session.env # act = np.asarray([0,]) - obs, reward, done, info = env.step( - 0 - ) # apply the 'do nothing' action + obs, reward, done, info = env.step(0) # apply the 'do nothing' action assert np.array_equal( obs, @@ -192,17 +186,15 @@ class TestNodeStatuses: with temp_primaite_session as session: env = session.env obs, _, _, _ = env.step(0) # apply the 'do nothing' action - assert np.array_equal( - obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0] - ) + print(obs) + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) @pytest.mark.parametrize( "temp_primaite_session", [ [ - TEST_CONFIG_ROOT - / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ] ], diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index 8c8d2b80..ae0b0870 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session): # Check that both the transactions and av reward csv files exist for file in session.learning_path.iterdir(): if file.suffix == ".csv": - assert ( - "all_transactions" in file.name - or "average_reward_per_episode" in file.name - ) + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name # Check that both the transactions and av reward csv files exist for file in session.evaluation_path.iterdir(): if file.suffix == ".csv": - assert ( - "all_transactions" in file.name - or "average_reward_per_episode" in file.name - ) + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name _LOGGER.debug("Inspecting files in temp session path...") for dir_path, dir_names, file_names in os.walk(session_path): diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index e7312777..fb7dc83d 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,13 +1,7 @@ """Used to test Active Node functions.""" import pytest -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode "starting_operating_state, expected_operating_state", [(HardwareState.RESETTING, HardwareState.ON)], ) -def test_node_resets_correctly( - starting_operating_state, expected_operating_state -): +def test_node_resets_correctly(starting_operating_state, expected_operating_state): """Tests that a node resets correctly.""" active_node = ActiveNode( node_id="0", @@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state): file_system_state="GOOD", config_values=1, ) - service_attributes = Service( - name="node", port="80", software_state=SoftwareState.COMPROMISED - ) + service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED) service_node.add_service(service_attributes) for x in range(5): diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 9e760b23..4383fc1b 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state): (HardwareState.ON, SoftwareState.OVERWHELMED), ], ) -def test_service_state_change_if_not_comprised( - operating_state, expected_state -): +def test_service_state_change_if_not_comprised(operating_state, expected_state): """ Test that a node cannot change the state of a running service. @@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised( service = Service("TCP", 80, SoftwareState.GOOD) service_node.add_service(service) - service_node.set_service_state_if_not_compromised( - "TCP", SoftwareState.OVERWHELMED - ) + service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED) assert service_node.get_service_state("TCP") == expected_state diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 1cf63cde..5d55b9c9 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite): # TEMP - random action for now # action = env.blue_agent_action(obs) action = 0 - print("Episode:", episode, "\nStep:", step) if step == 5: # [1, 1, 2, 1, 1, 1] # Creates an ACL rule @@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session): "temp_primaite_session", [ [ - TEST_CONFIG_ROOT - / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ] ], diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 88bc802b..d7fe4e50 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT def test_legacy_lay_down_config_yaml_conversion(): """Tests the conversion of legacy lay down config files.""" - legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" with open(legacy_path, "r") as file: legacy_dict = yaml.safe_load(file) @@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion(): with open(new_path, "r") as file: new_dict = yaml.safe_load(file) - converted_dict = training_config.convert_legacy_training_config_dict( - legacy_dict - ) + converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict) for key, value in new_dict.items(): assert converted_dict[key] == value @@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion(): def test_create_config_values_main_from_file(): """Tests creating an instance of TrainingConfig from file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" training_config.load(new_path) def test_create_config_values_main_from_legacy_file(): """Tests creating an instance of TrainingConfig from legacy file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" training_config.load(new_path, legacy_file=True)