#917 - Fixed the RLlib integration

- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
Chris McCarthy
2023-06-30 16:52:57 +01:00
parent 00185d3dad
commit e11fd2ced4
43 changed files with 284 additions and 896 deletions

View File

@@ -13,7 +13,7 @@ repos:
rev: 23.1.0
hooks:
- id: black
args: [ "--line-length=79" ]
args: [ "--line-length=120" ]
additional_dependencies:
- jupyter
- repo: http://github.com/pycqa/isort

View File

@@ -72,9 +72,9 @@ primaite = "primaite.cli:app"
[tool.isort]
profile = "black"
line_length = 79
line_length = 120
force_sort_within_sections = "False"
order_by_type = "False"
[tool.black]
line-length = 79
line-length = 120

View File

@@ -19,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
def _get_primaite_config():
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
if not config_path.exists():
config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
)
)
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
with open(config_path, "r") as file:
primaite_config = yaml.safe_load(file)
log_level_map = {
@@ -34,9 +30,7 @@ def _get_primaite_config():
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[
primaite_config["logging"]["log_level"]
]
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
@@ -82,14 +76,9 @@ class _LevelFormatter(Formatter):
super().__init__()
if "fmt" in kwargs:
raise ValueError(
"Format string must be passed to level-surrogate formatters, "
"not this one"
)
raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one")
self.formats = sorted(
(level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()
)
self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())
def format(self, record: LogRecord) -> str:
"""Overrides ``Formatter.format``."""
@@ -110,13 +99,9 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][
"WARNING"
],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][
"CRITICAL"
],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
}
)

View File

@@ -10,9 +10,7 @@ class AccessControlList:
def __init__(self):
"""Init."""
self.acl: Dict[
str, AccessControlList
] = {} # A dictionary of ACL Rules
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
@@ -27,29 +25,16 @@ class AccessControlList:
True if match; False otherwise.
"""
if (
(
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == "ANY"
and _rule.get_dest_ip() == _dest_ip_address
)
or (
_rule.get_source_ip() == _source_ip_address
and _rule.get_dest_ip() == "ANY"
)
or (
_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY"
)
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
):
return True
else:
return False
def is_blocked(
self, _source_ip_address, _dest_ip_address, _protocol, _port
):
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
"""
Checks for rules that block a protocol / port.
@@ -63,15 +48,9 @@ class AccessControlList:
Indicates block if all conditions are satisfied.
"""
for rule_key, rule_value in self.acl.items():
if self.check_address_match(
rule_value, _source_ip_address, _dest_ip_address
):
if (
rule_value.get_protocol() == _protocol
or rule_value.get_protocol() == "ANY"
) and (
str(rule_value.get_port()) == str(_port)
or rule_value.get_port() == "ANY"
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
):
# There's a matching rule. Get the permission
if rule_value.get_permission() == "DENY":
@@ -93,9 +72,7 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
new_rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(new_rule)
self.acl[hash_value] = new_rule
@@ -110,9 +87,7 @@ class AccessControlList:
_protocol: the protocol
_port: the port
"""
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
# There will not always be something 'popable' since the agent will be trying random things
try:
@@ -124,9 +99,7 @@ class AccessControlList:
"""Removes all rules."""
self.acl.clear()
def get_dictionary_hash(
self, _permission, _source_ip, _dest_ip, _protocol, _port
):
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
"""
Produces a hash value for a rule.
@@ -140,8 +113,6 @@ class AccessControlList:
Returns:
Hash value based on rule parameters.
"""
rule = ACLRule(
_permission, _source_ip, _dest_ip, _protocol, str(_port)
)
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
hash_value = hash(rule)
return hash_value

View File

@@ -5,7 +5,7 @@ import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Dict, Final, Optional, Union
from typing import Dict, Final, Union
from uuid import uuid4
import yaml
@@ -51,16 +51,12 @@ class AgentSessionABC(ABC):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(
self._training_config_path
)
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(
self._lay_down_config_path
)
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self.output_verbose_level = self._training_config.output_verbose_level
self._env: Primaite
@@ -132,9 +128,7 @@ class AgentSessionABC(ABC):
"learning": {"total_episodes": None, "total_time_steps": None},
"evaluation": {"total_episodes": None, "total_time_steps": None},
"env": {
"training_config": self._training_config.to_dict(
json_serializable=True
),
"training_config": self._training_config.to_dict(json_serializable=True),
"lay_down_config": self._lay_down_config,
},
}
@@ -161,19 +155,11 @@ class AgentSessionABC(ABC):
metadata_dict["end_datetime"] = datetime.now().isoformat()
if not self.is_eval:
metadata_dict["learning"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["learning"][
"total_time_steps"
] = self._env.total_step_count # noqa
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
else:
metadata_dict["evaluation"][
"total_episodes"
] = self._env.episode_count # noqa
metadata_dict["evaluation"][
"total_time_steps"
] = self._env.total_step_count # noqa
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -184,12 +170,9 @@ class AgentSessionABC(ABC):
@abstractmethod
def _setup(self):
_LOGGER.info(
"Welcome to the Primary-level AI Training Environment "
f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(
f"The output directory for this session is: {self.session_path}"
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
)
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
self._write_session_metadata_file()
self._can_learn = True
self._can_evaluate = False
@@ -201,17 +184,11 @@ class AgentSessionABC(ABC):
@abstractmethod
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if self._can_learn:
@@ -225,17 +202,11 @@ class AgentSessionABC(ABC):
@abstractmethod
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
@@ -281,9 +252,7 @@ class AgentSessionABC(ABC):
else:
# Session path does not exist
msg = (
f"Failed to load PrimAITE Session, path does not exist: {path}"
)
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
pass
@@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC):
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
_LOGGER.warning("Deterministic agents cannot learn")
@@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC):
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
self._env.set_as_eval() # noqa
self.is_eval = True
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
obs = self._env.reset()
for episode in range(episodes):
# Reset env and collect initial observation

View File

@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
if (
self._training_config.hard_coded_agent_view
== HardCodedAgentView.BASIC
):
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
else:
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
port = green_ier.get_port()
# Can be blocked by an ACL or by default (no allow rule exists)
if acl.is_blocked(
source_node_address, dest_node_address, protocol, port
):
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
blocked_green_iers[green_ier_id] = green_ier
return blocked_green_iers
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = dest_node_id
if protocol != "ANY":
protocol = services_list[
protocol - 1
] # -1 as dont have to account for ANY in list of services
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
matching_rules = acl.get_relevant_rules(
source_node_address, dest_node_address, protocol, port
)
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_allow_acl_rules(
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = rule.get_source_ip()
action_source_id = int(
get_node_of_ip(action_source_ip, self._env.nodes)
)
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
action_destination_ip = rule.get_dest_ip()
action_destination_id = int(
get_node_of_ip(
action_destination_ip, self._env.nodes
)
)
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
action_protocol_name = rule.get_protocol()
action_protocol = (
self._env.services_list.index(action_protocol_name)
+ 1
self._env.services_list.index(action_protocol_name) + 1
) # convert name e.g. 'TCP' to index
action_port_name = rule.get_port()
action_port = (
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if not found_action:
# Which Green IERS are blocked
blocked_green_iers = self.get_blocked_green_iers(
self._env.green_iers, self._env.acl, self._env.nodes
)
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
for ier_key, ier in blocked_green_iers.items():
# Which ALLOW rules are allowing this IER (none)
allowing_rules = self.get_allow_acl_rules_for_ier(
ier, self._env.acl, self._env.nodes
)
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
# If there are no blocking rules, it may be being blocked by default
# If there is already an allow rule
node_id_to_check = int(ier.get_source_node_id())
service_name_to_check = ier.get_protocol()
service_id_to_check = self._env.services_list.index(
service_name_to_check
)
service_id_to_check = self._env.services_list.index(service_name_to_check)
# Service state of the source node in the ier
service_state = s[service_id_to_check][node_id_to_check - 1]
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
if len(r_obs) == 4: # only 1 service
s = [*s]
number_of_nodes = len(
[i for i in o if i != "NONE"]
) # number of nodes (not links)
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
for service_num, service_states in enumerate(s):
comprimised_states = [
n for n, i in enumerate(service_states) if i == "COMPROMISED"
]
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
if len(comprimised_states) == 0:
# No states are COMPROMISED, try the next service
continue
compromised_node = (
np.random.choice(comprimised_states) + 1
) # +1 as 0 would be any
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
action_decision = "DELETE"
action_permission = "ALLOW"
action_source_ip = compromised_node
# Randomly select a destination ID to block
action_destination_ip = np.random.choice(
list(range(1, number_of_nodes + 1)) + ["ANY"]
)
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
action_destination_ip = (
int(action_destination_ip)
if action_destination_ip != "ANY"
else action_destination_ip
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
)
action_protocol = service_num + 1 # +1 as 0 is any
# Randomly select a port to block

View File

@@ -1,9 +1,5 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_action_node_enum,
transform_change_obs_readable,
)
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
class HardCodedNodeAgent(HardCodedAgentSessionABC):
@@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
if os_state == "OFF":
action_node_id = x + 1
action_node_property = "OPERATING"
property_action = (
"ON" # Why reset it when we can just turn it on
)
action_service_index = (
0 # does nothing isn't relevant for operating state
)
property_action = "ON" # Why reset it when we can just turn it on
action_service_index = 0 # does nothing isn't relevant for operating state
action = [
action_node_id,
action_node_property,

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
from typing import Union
import tensorflow as tf
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.ppo import PPOConfig
@@ -14,11 +13,7 @@ from ray.tune.registry import register_env
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
from primaite.common.enums import (
AgentFramework,
AgentIdentifier,
DeepLearningFramework,
)
from primaite.common.enums import AgentFramework, AgentIdentifier
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
@@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.RLLIB:
msg = (
f"Expected RLLIB agent_framework, "
f"got {self._training_config.agent_framework}"
)
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_config_class = A2CConfig
else:
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier.value}"
)
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
_LOGGER.error(msg)
raise ValueError(msg)
self._agent_config: PPOConfig
@@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC):
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._current_result[
"episodes_total"
]
metadata_dict["total_time_steps"] = self._current_result[
"timesteps_total"
]
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
filepath = self.session_path / "session_metadata.json"
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
@@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC):
),
)
self._agent_config.training(
train_batch_size=self._training_config.num_steps
)
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
self._agent_config.rollouts(
@@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC):
num_envs_per_worker=1,
horizon=self._training_config.num_steps,
)
self._agent: Algorithm = self._agent_config.build(
logger_creator=_custom_log_creator(self.session_path)
)
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
# Temporarily override train_batch_size and horizon
if time_steps:
self._agent_config.train_batch_size = time_steps
self._agent_config.horizon = time_steps
if not episodes:
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
if (
self._training_config.deep_learning_framework
!= DeepLearningFramework.TORCH
):
policy = self._agent.get_policy()
tf.compat.v1.summary.FileWriter(
self.session_path / "ray_results", policy.get_session().graph
)
super().learn()
self._agent.stop()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
raise NotImplementedError

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional, Union
from typing import Union
import numpy as np
from stable_baselines3 import A2C, PPO
@@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC):
def __init__(self, training_config_path, lay_down_config_path):
super().__init__(training_config_path, lay_down_config_path)
if not self._training_config.agent_framework == AgentFramework.SB3:
msg = (
f"Expected SB3 agent_framework, "
f"got {self._training_config.agent_framework}"
)
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
_LOGGER.error(msg)
raise ValueError(msg)
if self._training_config.agent_identifier == AgentIdentifier.PPO:
@@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC):
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
self._agent_class = A2C
else:
msg = (
"Expected PPO or A2C agent_identifier, "
f"got {self._training_config.agent_identifier}"
)
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
_LOGGER.error(msg)
raise ValueError(msg)
@@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC):
self._env,
verbose=self.output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path,
tensorboard_log=str(self._tensorboard_log_path),
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (
episode_count == self._training_config.num_episodes
):
checkpoint_path = (
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
)
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
@@ -85,57 +75,36 @@ class SB3Agent(AgentSessionABC):
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
self.is_eval = False
_LOGGER.info(
f"Beginning learning for {episodes} episodes @"
f" {time_steps} time steps..."
)
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self.close()
self._env.reset()
self._env.close()
super().learn()
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of steps per episode. Optional. If not
passed, the value from the training config will be used.
:param episodes: The number of episodes. Optional. If not
passed, the value from the training config will be used.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
@@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC):
else:
deterministic_str = "non-deterministic"
_LOGGER.info(
f"Beginning {deterministic_str} evaluation for "
f"{episodes} episodes @ {time_steps} time steps..."
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
)
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(
obs, deterministic=deterministic
)
action, _states = self._agent.predict(obs, deterministic=deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
self._env.reset()
self._env.close()
super().evaluate()
@classmethod

View File

@@ -1,9 +1,5 @@
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
transform_action_acl_enum,
transform_action_node_enum,
)
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
class RandomAgent(HardCodedAgentSessionABC):

View File

@@ -24,9 +24,7 @@ def transform_action_node_readable(action):
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -117,11 +115,7 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
@@ -173,9 +167,7 @@ def transform_change_obs_readable(obs):
for service in range(3, obs.shape[1]):
# Links bit/s don't have a service state
service_states = [
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
]
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
new_obs.append(service_states)
return new_obs
@@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
return new_obs
def describe_obs_change(
obs1, obs2, num_nodes=10, num_links=10, num_services=1
):
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""
Return string describing change between two observations.
@@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link):
TODO: Typehint params and return.
"""
# Indexes where a change has occurred, not including 0th index
index_changed = [
i for i in range(1, len(obs_change)) if obs_change[i] != -1
]
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
# Node pol types, Indexes >= 3 are service nodes
NodePOLTypes = [
NodePOLType(i).name
if i < 3
else NodePOLType(3).name + " " + str(i - 3)
for i in index_changed
]
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
# Account for hardware states, software states and links
states = [
LinkStatus(obs_change[i]).name
@@ -367,9 +350,7 @@ def transform_action_node_readable(action):
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (
action_node_property == "OS" or action_node_property == "SERVICE"
) and action[2] <= 1:
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
@@ -397,9 +378,7 @@ def node_action_description(action):
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = (
f"NODE {node_id}, {node_property}, SET TO {property_action}"
)
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
elif node_property == "SERVICE":
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
else:
@@ -522,11 +501,7 @@ def is_valid_acl_action(action):
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":

View File

@@ -56,9 +56,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum(
"LogLevel", {k: k for k in logging._levelToName.values()}
) # noqa
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()
@@ -124,21 +122,12 @@ def setup(overwrite_existing: bool = True):
app_dirs = PlatformDirs(appname="primaite")
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
pkg_config_path = Path(
pkg_resources.resource_filename(
"primaite", "setup/_package_data/primaite_config.yaml"
)
)
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, user_config_path)
from primaite import getLogger
from primaite.setup import (
old_installation_clean_up,
reset_demo_notebooks,
reset_example_configs,
setup_app_dirs,
)
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs
_LOGGER = getLogger(__name__)
@@ -188,9 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
@app.command()
def plotly_template(
template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None
):
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
"""
View or set the plotly template for Session plots.
@@ -208,14 +195,10 @@ def plotly_template(
primaite_config = yaml.safe_load(file)
if template:
primaite_config["session"]["outputs"]["plots"][
"template"
] = template.value
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
with open(user_config_path, "w") as file:
yaml.dump(primaite_config, file)
print(f"PrimAITE plotly template: {template.value}")
else:
template = primaite_config["session"]["outputs"]["plots"][
"template"
]
template = primaite_config["session"]["outputs"]["plots"]["template"]
print(f"PrimAITE plotly template: {template}")

View File

@@ -8,14 +8,10 @@ from primaite import getLogger, USERS_CONFIG_DIR
_LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = (
USERS_CONFIG_DIR / "example_config" / "lay_down"
)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
def convert_legacy_lay_down_config_dict(
legacy_config_dict: Dict[str, Any]
) -> Dict[str, Any]:
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a legacy lay down config dict to the new format.

View File

@@ -20,9 +20,7 @@ from primaite.common.enums import (
_LOGGER = getLogger(__name__)
_EXAMPLE_TRAINING: Final[Path] = (
USERS_CONFIG_DIR / "example_config" / "training"
)
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
def main_training_config_path() -> Path:
@@ -68,9 +66,7 @@ class TrainingConfig:
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes"
observation_space: dict = field(
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
)
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
"The observation space config dict"
time_delay: int = 10
@@ -180,9 +176,7 @@ class TrainingConfig:
"The time taken to scan the file system"
@classmethod
def from_dict(
cls, config_dict: Dict[str, Union[str, int, bool]]
) -> TrainingConfig:
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
"""
Create an instance of TrainingConfig from a dict.
@@ -238,9 +232,7 @@ class TrainingConfig:
return tc
def load(
file_path: Union[str, Path], legacy_file: bool = False
) -> TrainingConfig:
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
"""
Read in a training config yaml file.
@@ -271,10 +263,7 @@ def load(
try:
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
f"from the training config file {file_path}"
)
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
_LOGGER.critical(msg, exc_info=True)
raise e
msg = f"Cannot load the training config as it does not exist: {file_path}"
@@ -314,9 +303,7 @@ def convert_legacy_training_config_dict(
"output_verbose_level": output_verbose_level.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
legacy_config_dict["sessionType"] = session_type_map[
legacy_config_dict["sessionType"]
]
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
if new_key:

View File

@@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(
observation_shape, dtype=self._DATA_TYPE
)
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
@@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent):
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][
2
] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][3] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
@@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
self.current_observation[item_index][service_index] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
@@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent):
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
@@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent):
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(
service
).value
service_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
@@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels]
* self.env.num_links
* self._entries_per_link
)
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
@@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [
protocol.get_load() for protocol in link.protocol_list
]
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
@@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
obs.append(int(traffic_level))

View File

@@ -12,13 +12,11 @@ from matplotlib import pyplot as plt
from primaite import getLogger
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import (
is_valid_acl_action_extra,
is_valid_node_action,
)
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
AgentFramework,
FileSystemState,
HardwareState,
NodePOLInitiator,
@@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node import Node
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.passive_node import PassiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.green_pol import apply_iers, apply_node_pol
from primaite.pol.ier import IER
from primaite.pol.red_agent_pol import (
apply_red_agent_iers,
apply_red_agent_node_pol,
)
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
from primaite.transactions.transaction import Transaction
from primaite.utils.session_output_writer import SessionOutputWriter
@@ -85,9 +78,7 @@ class Primaite(Env):
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self.training_config: TrainingConfig = training_config.load(
training_config_path
)
self.training_config: TrainingConfig = training_config.load(training_config_path)
_LOGGER.info(f"Using: {str(self.training_config)}")
# Number of steps in an episode
@@ -238,25 +229,22 @@ class Primaite(Env):
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.error(
f"Invalid action type selected: {self.training_config.action_type}"
)
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
self.episode_av_reward_writer = SessionOutputWriter(
self, transaction_writer=False, learning_session=True
)
self.transaction_writer = SessionOutputWriter(
self, transaction_writer=True, learning_session=True
)
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
@property
def actual_episode_count(self) -> int:
    """
    Return the episode number to use for logging and output.

    When the configured agent framework is RLlib the raw ``episode_count``
    is reported one lower; for every other framework it is returned as is.

    :return: ``episode_count - 1`` for RLlib, otherwise ``episode_count``.
    """
    # NOTE(review): presumably compensates for an extra reset() that RLlib
    # performs before the first real episode - confirm against the RLlib agent.
    is_rllib = self.training_config.agent_framework is AgentFramework.RLLIB
    return self.episode_count - 1 if is_rllib else self.episode_count
def set_as_eval(self):
"""Set the writers to write to eval directories."""
self.episode_av_reward_writer = SessionOutputWriter(
self, transaction_writer=False, learning_session=False
)
self.transaction_writer = SessionOutputWriter(
self, transaction_writer=True, learning_session=False
)
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
self.episode_count = 0
self.step_count = 0
self.total_step_count = 0
@@ -268,8 +256,8 @@ class Primaite(Env):
Returns:
Environment observation space (reset)
"""
if self.episode_count > 0:
csv_data = self.episode_count, self.average_reward
if self.actual_episode_count > 0:
csv_data = self.actual_episode_count, self.average_reward
self.episode_av_reward_writer.write(csv_data)
self.episode_count += 1
@@ -291,6 +279,7 @@ class Primaite(Env):
# Update observations space and return
self.update_environent_obs()
return self.env_obs
def step(self, action):
@@ -319,9 +308,7 @@ class Primaite(Env):
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
self.agent_identifier, self.episode_count, self.step_count
)
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
# Load the initial observation space into the transaction
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
# Load the action space into the transaction
@@ -350,9 +337,7 @@ class Primaite(Env):
self.nodes_post_pol = copy.deepcopy(self.nodes)
self.links_post_pol = copy.deepcopy(self.links)
# Reference
apply_node_pol(
self.nodes_reference, self.node_pol, self.step_count
) # Node PoL
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
apply_iers(
self.network_reference,
self.nodes_reference,
@@ -371,9 +356,7 @@ class Primaite(Env):
self.acl,
self.step_count,
)
apply_red_agent_node_pol(
self.nodes, self.red_iers, self.red_node_pol, self.step_count
)
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
# Take snapshots of nodes and links
self.nodes_post_red = copy.deepcopy(self.nodes)
self.links_post_red = copy.deepcopy(self.links)
@@ -389,11 +372,7 @@ class Primaite(Env):
self.step_count,
self.training_config,
)
_LOGGER.debug(
f"Episode: {self.episode_count}, "
f"Step {self.step_count}, "
f"Reward: {reward}"
)
_LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -401,10 +380,7 @@ class Primaite(Env):
# For evaluation, need to trigger the done value = True when
# step count is reached in order to prevent neverending episode
done = True
_LOGGER.info(
f"Episode: {self.episode_count}, "
f"Average Reward: {self.average_reward}"
)
_LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
# Load the reward into the transaction
transaction.reward = reward
@@ -417,11 +393,21 @@ class Primaite(Env):
transaction.obs_space_post = copy.deepcopy(self.env_obs)
# Write transaction to file
if self.actual_episode_count > 0:
self.transaction_writer.write(transaction)
# Return
return self.env_obs, reward, done, self.step_info
def close(self):
    """
    Close the environment and both session output writers.

    Calls the parent ``close`` and then closes ``transaction_writer`` and
    ``episode_av_reward_writer``. The calls are chained with ``try/finally``
    so that a failure in an earlier step does not leave a later writer open;
    any exception raised still propagates to the caller.
    """
    try:
        super().close()
    finally:
        try:
            self.transaction_writer.close()
        finally:
            # Always attempted, even if closing the transaction writer failed.
            self.episode_av_reward_writer.close()
def init_acl(self):
"""Initialise the Access Control List."""
self.acl.remove_all_rules()
@@ -431,12 +417,7 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
_LOGGER.debug("Link ID: " + link_value.get_id())
for protocol in link_value.protocol_list:
_LOGGER.debug(
" Protocol: "
+ protocol.get_name().name
+ ", Load: "
+ str(protocol.get_load())
)
_LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
def interpret_action_and_apply(self, _action):
"""
@@ -450,13 +431,9 @@ class Primaite(Env):
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
_LOGGER.error("Invalid action type found")
@@ -541,10 +518,7 @@ class Primaite(Env):
elif property_action == 2:
# Repair
# You cannot repair a destroyed file system - it needs restoring
if (
node.file_system_state_actual
!= FileSystemState.DESTROYED
):
if node.file_system_state_actual != FileSystemState.DESTROYED:
node.set_file_system_state(FileSystemState.REPAIRING)
elif property_action == 3:
# Restore
@@ -587,9 +561,7 @@ class Primaite(Env):
acl_rule_source = "ANY"
else:
node = list(self.nodes.values())[action_source_ip - 1]
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
acl_rule_source = node.ip_address
else:
return
@@ -598,9 +570,7 @@ class Primaite(Env):
acl_rule_destination = "ANY"
else:
node = list(self.nodes.values())[action_destination_ip - 1]
if isinstance(node, ServiceNode) or isinstance(
node, ActiveNode
):
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
acl_rule_destination = node.ip_address
else:
return
@@ -685,9 +655,7 @@ class Primaite(Env):
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
self.obs_handler = ObservationsHandler.from_config(
self, self.obs_config
)
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation
@@ -794,9 +762,7 @@ class Primaite(Env):
service_protocol = service["name"]
service_port = service["port"]
service_state = SoftwareState[service["state"]]
node.add_service(
Service(service_protocol, service_port, service_state)
)
node.add_service(Service(service_protocol, service_port, service_state))
else:
# Bad formatting
pass
@@ -849,9 +815,7 @@ class Primaite(Env):
dest_node_ref: Node = self.nodes_reference[link_destination]
# Add link to network (reference)
self.network_reference.add_edge(
source_node_ref, dest_node_ref, id=link_name
)
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
# Add link to link dictionary (reference)
self.links_reference[link_name] = Link(
@@ -1126,9 +1090,7 @@ class Primaite(Env):
# All nodes have these parameters
node_id = item["node_id"]
node_class = item["node_class"]
node_hardware_state: HardwareState = HardwareState[
item["hardware_state"]
]
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
node: NodeUnion = self.nodes[node_id]
node_ref = self.nodes_reference[node_id]
@@ -1249,11 +1211,7 @@ class Primaite(Env):
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}

View File

@@ -41,29 +41,19 @@ def calculate_reward_function(
reference_node = reference_nodes[node_key]
# Hardware State
reward_value += score_node_operating_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
# Software State
if isinstance(final_node, ActiveNode) or isinstance(
final_node, ServiceNode
):
reward_value += score_node_os_state(
final_node, initial_node, reference_node, config_values
)
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
# Service State
if isinstance(final_node, ServiceNode):
reward_value += score_node_service_state(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
# File System State
if isinstance(final_node, ActiveNode):
reward_value += score_node_file_system(
final_node, initial_node, reference_node, config_values
)
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
# Go through each red IER - penalise if it is running
for ier_key, ier_value in red_iers.items():
@@ -82,10 +72,7 @@ def calculate_reward_function(
if step_count >= start_step and step_count <= stop_step:
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
)
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
if live_blocked and not reference_blocked:
reward_value += ier_reward
@@ -107,9 +94,7 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(
final_node, initial_node, reference_node, config_values
):
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the hardware state of a node.
@@ -158,9 +143,7 @@ def score_node_operating_state(
return score
def score_node_os_state(
final_node, initial_node, reference_node, config_values
):
def score_node_os_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the Software State of a node.
@@ -211,9 +194,7 @@ def score_node_os_state(
return score
def score_node_service_state(
final_node, initial_node, reference_node, config_values
):
def score_node_service_state(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the service state(s) of a node.
@@ -285,9 +266,7 @@ def score_node_service_state(
return score
def score_node_file_system(
final_node, initial_node, reference_node, config_values
):
def score_node_file_system(final_node, initial_node, reference_node, config_values):
"""
Calculates score relating to the file system state of a node.

View File

@@ -8,9 +8,7 @@ from primaite.common.protocol import Protocol
class Link(object):
"""Link class."""
def __init__(
self, _id, _bandwidth, _source_node_name, _dest_node_name, _services
):
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
"""
Init.

View File

@@ -32,11 +32,7 @@ if __name__ == "__main__":
parser.add_argument("--ldc")
args = parser.parse_args()
if not args.tc:
_LOGGER.error(
"Please provide a training config file using the --tc " "argument"
)
_LOGGER.error("Please provide a training config file using the --tc " "argument")
if not args.ldc:
_LOGGER.error(
"Please provide a lay down config file using the --ldc " "argument"
)
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -3,13 +3,7 @@
import logging
from typing import Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.config.training_config import TrainingConfig
from primaite.nodes.node import Node
@@ -44,9 +38,7 @@ class ActiveNode(Node):
:param file_system_state: The node file system state
:param config_values: The config values
"""
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
self.ip_address: str = ip_address
# Related to Software
self._software_state: SoftwareState = software_state
@@ -87,9 +79,7 @@ class ActiveNode(Node):
f"Node.software_state:{self._software_state}"
)
def set_software_state_if_not_compromised(
self, software_state: SoftwareState
):
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
"""
Sets Software State if the node is not compromised.
@@ -100,9 +90,7 @@ class ActiveNode(Node):
if self._software_state != SoftwareState.COMPROMISED:
self._software_state = software_state
if software_state == SoftwareState.PATCHING:
self.patching_count = (
self.config_values.os_patching_duration
)
self.patching_count = self.config_values.os_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so OS State cannot be changed."
@@ -129,14 +117,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
@@ -149,9 +133,7 @@ class ActiveNode(Node):
f"Node.file_system_state.actual:{self.file_system_state_actual}"
)
def set_file_system_state_if_not_compromised(
self, file_system_state: FileSystemState
):
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
"""
Sets the file system state (actual and observed) if not in a compromised state.
@@ -168,14 +150,10 @@ class ActiveNode(Node):
self.file_system_state_actual = file_system_state
if file_system_state == FileSystemState.REPAIRING:
self.file_system_action_count = (
self.config_values.file_system_repairing_limit
)
self.file_system_action_count = self.config_values.file_system_repairing_limit
self.file_system_state_observed = FileSystemState.REPAIRING
elif file_system_state == FileSystemState.RESTORING:
self.file_system_action_count = (
self.config_values.file_system_restoring_limit
)
self.file_system_action_count = self.config_values.file_system_restoring_limit
self.file_system_state_observed = FileSystemState.RESTORING
elif file_system_state == FileSystemState.GOOD:
self.file_system_state_observed = FileSystemState.GOOD
@@ -191,9 +169,7 @@ class ActiveNode(Node):
def start_file_system_scan(self):
"""Starts a file system scan."""
self.file_system_scanning = True
self.file_system_scanning_count = (
self.config_values.file_system_scanning_limit
)
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
def update_file_system_state(self):
"""Updates file system status based on scanning/restore/repair cycle."""
@@ -212,10 +188,7 @@ class ActiveNode(Node):
self.file_system_state_observed = FileSystemState.GOOD
# Scanning updates
if (
self.file_system_scanning == True
and self.file_system_scanning_count < 0
):
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
self.file_system_state_observed = self.file_system_state_actual
self.file_system_scanning = False
self.file_system_scanning_count = 0

View File

@@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object):
self.end_step = _end_step
self.node_id = _node_id
self.node_pol_type = _node_pol_type
self.service_name = (
_service_name # Not used when not a service instruction
)
self.service_name = _service_name # Not used when not a service instruction
self.state = _state
def get_start_step(self):

View File

@@ -42,9 +42,7 @@ class NodeStateInstructionRed(object):
self.target_node_id = _target_node_id
self.initiator = _pol_initiator
self.pol_type: NodePOLType = _pol_type
self.service_name = (
pol_protocol # Not used when not a service instruction
)
self.service_name = pol_protocol # Not used when not a service instruction
self.state = _pol_state
self.source_node_id = _pol_source_node_id
self.source_node_service = _pol_source_node_service

View File

@@ -28,9 +28,7 @@ class PassiveNode(Node):
:param config_values: Config values.
"""
# Pass through to Super for now
super().__init__(
node_id, name, node_type, priority, hardware_state, config_values
)
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
@property
def ip_address(self) -> str:

View File

@@ -3,13 +3,7 @@
import logging
from typing import Dict, Final
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -110,9 +104,7 @@ class ServiceNode(ActiveNode):
return False
return False
def set_service_state(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode):
) or software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "
@@ -143,9 +133,7 @@ class ServiceNode(ActiveNode):
f"Node.services[<key>].software_state:{software_state}"
)
def set_service_state_if_not_compromised(
self, protocol_name: str, software_state: SoftwareState
):
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
"""
Sets the software_state of a service (protocol) on the node.
@@ -161,9 +149,7 @@ class ServiceNode(ActiveNode):
if service_value.software_state != SoftwareState.COMPROMISED:
service_value.software_state = software_state
if software_state == SoftwareState.PATCHING:
service_value.patching_count = (
self.config_values.service_patching_duration
)
service_value.patching_count = self.config_values.service_patching_duration
else:
_LOGGER.info(
f"The Nodes hardware state is OFF so the state of a service "

View File

@@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_green import (
NodeStateInstructionGreen,
)
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
@@ -93,9 +86,7 @@ def apply_iers(
and source_node.software_state != SoftwareState.PATCHING
):
if source_node.has_service(protocol):
if source_node.service_running(
protocol
) and not source_node.service_is_overwhelmed(protocol):
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
source_valid = True
else:
source_valid = False
@@ -110,10 +101,7 @@ def apply_iers(
# 2. Check the dest node situation
if dest_node.node_type == NodeType.SWITCH:
# It's a switch
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
dest_valid = True
else:
# IER no longer valid
@@ -123,14 +111,9 @@ def apply_iers(
pass
else:
# It's not a switch or an actuator (so active node)
if (
dest_node.hardware_state == HardwareState.ON
and dest_node.software_state != SoftwareState.PATCHING
):
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
if dest_node.has_service(protocol):
if dest_node.service_running(
protocol
) and not dest_node.service_is_overwhelmed(protocol):
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
dest_valid = True
else:
dest_valid = False
@@ -143,9 +126,7 @@ def apply_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -176,10 +157,7 @@ def apply_iers(
# We might have a switch in the path, so check all nodes are operational
for node in path_node_list:
if (
node.hardware_state != HardwareState.ON
or node.software_state == SoftwareState.PATCHING
):
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
path_valid = False
if path_valid:
@@ -191,15 +169,11 @@ def apply_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (
link.get_current_load() + load
) > link.get_bandwidth():
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -226,9 +200,7 @@ def apply_iers(
else:
# One of the nodes is not operational
if _VERBOSE:
print(
"Path not valid - one or more nodes not operational"
)
print("Path not valid - one or more nodes not operational")
pass
else:
@@ -243,9 +215,7 @@ def apply_iers(
def apply_node_pol(
nodes: Dict[str, NodeUnion],
node_pol: Dict[
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
],
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
step: int,
):
"""
@@ -277,22 +247,16 @@ def apply_node_pol(
elif node_pol_type == NodePOLType.OS:
# Change OS state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_software_state_if_not_compromised(state)
elif node_pol_type == NodePOLType.SERVICE:
# Change a service state
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
if isinstance(node, ServiceNode):
node.set_service_state_if_not_compromised(
service_name, state
)
node.set_service_state_if_not_compromised(service_name, state)
else:
# Change the file system status
if isinstance(node, ActiveNode) or isinstance(
node, ServiceNode
):
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
node.set_file_system_state_if_not_compromised(state)
else:
# PoL is not valid in this time step

View File

@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path
from primaite.acl.access_control_list import AccessControlList
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
NodePOLInitiator,
NodePOLType,
NodeType,
SoftwareState,
)
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
@@ -83,10 +77,7 @@ def apply_red_agent_iers(
if source_node.hardware_state == HardwareState.ON:
if source_node.has_service(protocol):
# Red agents IERs can only be valid if the source service is in a compromised state
if (
source_node.get_service_state(protocol)
== SoftwareState.COMPROMISED
):
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
source_valid = True
else:
source_valid = False
@@ -124,9 +115,7 @@ def apply_red_agent_iers(
dest_valid = False
# 3. Check that the ACL doesn't block it
acl_block = acl.is_blocked(
source_node.ip_address, dest_node.ip_address, protocol, port
)
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
if acl_block:
if _VERBOSE:
print(
@@ -170,15 +159,11 @@ def apply_red_agent_iers(
# Check that the link capacity is not exceeded by the new load
while count < path_node_list_length - 1:
# Get the link between the next two nodes
edge_dict = network.get_edge_data(
path_node_list[count], path_node_list[count + 1]
)
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
link_id = edge_dict[0].get("id")
link = links[link_id]
# Check whether the new load exceeds the bandwidth
if (
link.get_current_load() + load
) > link.get_bandwidth():
if (link.get_current_load() + load) > link.get_bandwidth():
link_capacity_exceeded = True
if _VERBOSE:
print("Link capacity exceeded")
@@ -203,23 +188,16 @@ def apply_red_agent_iers(
# This IER is now valid, so set it to running
ier_value.set_is_running(True)
if _VERBOSE:
print(
"Red IER was allowed to run in step "
+ str(step)
)
print("Red IER was allowed to run in step " + str(step))
else:
# One of the nodes is not operational
if _VERBOSE:
print(
"Path not valid - one or more nodes not operational"
)
print("Path not valid - one or more nodes not operational")
pass
else:
if _VERBOSE:
print(
"Red IER was NOT allowed to run in step " + str(step)
)
print("Red IER was NOT allowed to run in step " + str(step))
print("Source, Dest or ACL were not valid")
pass
# ------------------------------------
@@ -258,9 +236,7 @@ def apply_red_agent_node_pol(
state = node_instruction.get_state()
source_node_id = node_instruction.get_source_node_id()
source_node_service_name = node_instruction.get_source_node_service()
source_node_service_state_value = (
node_instruction.get_source_node_service_state()
)
source_node_service_state_value = node_instruction.get_source_node_service_state()
passed_checks = False
@@ -274,9 +250,7 @@ def apply_red_agent_node_pol(
passed_checks = True
elif initiator == NodePOLInitiator.IER:
# Need to check there is a red IER incoming
passed_checks = is_red_ier_incoming(
target_node, iers, pol_type
)
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
elif initiator == NodePOLInitiator.SERVICE:
# Need to check the condition of a service on another node
source_node = nodes[source_node_id]
@@ -304,9 +278,7 @@ def apply_red_agent_node_pol(
target_node.hardware_state = state
elif pol_type == NodePOLType.OS:
# Change OS state
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.software_state = state
elif pol_type == NodePOLType.SERVICE:
# Change a service state
@@ -314,15 +286,11 @@ def apply_red_agent_node_pol(
target_node.set_service_state(service_name, state)
else:
# Change the file system status
if isinstance(target_node, ActiveNode) or isinstance(
target_node, ServiceNode
):
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
target_node.set_file_system_state(state)
else:
if _VERBOSE:
print(
"Node Red Agent PoL not allowed - did not pass checks"
)
print("Node Red Agent PoL not allowed - did not pass checks")
else:
# PoL is not valid in this time step
pass
@@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type):
node_id = node.node_id
for ier_key, ier_value in iers.items():
if (
ier_value.get_is_running()
and ier_value.get_dest_node_id() == node_id
):
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
if (
node_pol_type == NodePOLType.OPERATING
or node_pol_type == NodePOLType.OS

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, Final, Optional, Union
from typing import Dict, Final, Union
from primaite import getLogger
from primaite.agents.agent import AgentSessionABC
@@ -9,18 +9,8 @@ from primaite.agents.hardcoded_acl import HardCodedACLAgent
from primaite.agents.hardcoded_node import HardCodedNodeAgent
from primaite.agents.rllib import RLlibAgent
from primaite.agents.sb3 import SB3Agent
from primaite.agents.simple import (
DoNothingACLAgent,
DoNothingNodeAgent,
DummyAgent,
RandomAgent,
)
from primaite.common.enums import (
ActionType,
AgentFramework,
AgentIdentifier,
SessionType,
)
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
from primaite.config import lay_down_config, training_config
from primaite.config.training_config import TrainingConfig
@@ -49,16 +39,12 @@ class PrimaiteSession:
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
self._training_config: Final[TrainingConfig] = training_config.load(
self._training_config_path
)
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._lay_down_config: Dict = lay_down_config.load(
self._lay_down_config_path
)
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
self._agent_session: AgentSessionABC = None # noqa
self.session_path: Path = None # noqa
@@ -69,28 +55,16 @@ class PrimaiteSession:
def setup(self):
"""Performs the session setup."""
if self._training_config.agent_framework == AgentFramework.CUSTOM:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
)
if (
self._training_config.agent_identifier
== AgentIdentifier.HARDCODED
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.HARDCODED}"
)
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
if self._training_config.action_type == ActionType.NODE:
# Deterministic Hardcoded Agent with Node Action Space
self._agent_session = HardCodedNodeAgent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ACL:
# Deterministic Hardcoded Agent with ACL Action Space
self._agent_session = HardCodedACLAgent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -100,24 +74,14 @@ class PrimaiteSession:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif (
self._training_config.agent_identifier
== AgentIdentifier.DO_NOTHING
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.DO_NOTHINGD}"
)
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}")
if self._training_config.action_type == ActionType.NODE:
self._agent_session = DoNothingNodeAgent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ACL:
# Do Nothing Agent with ACL Action Space
self._agent_session = DoNothingACLAgent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.action_type == ActionType.ANY:
# Deterministic Hardcoded Agent with ANY Action Space
@@ -127,49 +91,26 @@ class PrimaiteSession:
# Invalid AgentIdentifier ActionType combo
raise ValueError
elif (
self._training_config.agent_identifier
== AgentIdentifier.RANDOM
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.RANDOM}"
)
self._agent_session = RandomAgent(
self._training_config_path, self._lay_down_config_path
)
elif (
self._training_config.agent_identifier == AgentIdentifier.DUMMY
):
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Identifier ="
f" {AgentIdentifier.DUMMY}"
)
self._agent_session = DummyAgent(
self._training_config_path, self._lay_down_config_path
)
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
else:
# Invalid AgentFramework AgentIdentifier combo
raise ValueError
elif self._training_config.agent_framework == AgentFramework.SB3:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
)
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
# Stable Baselines3 Agent
self._agent_session = SB3Agent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
elif self._training_config.agent_framework == AgentFramework.RLLIB:
_LOGGER.debug(
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
)
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
# Ray RLlib Agent
self._agent_session = RLlibAgent(
self._training_config_path, self._lay_down_config_path
)
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
else:
# Invalid AgentFramework
@@ -182,35 +123,27 @@ class PrimaiteSession:
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Train the agent.
:param time_steps: The number of time steps per episode.
:param episodes: The number of episodes.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.EVAL:
self._agent_session.learn(time_steps, episodes, **kwargs)
self._agent_session.learn(**kwargs)
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs,
):
"""
Evaluate the agent.
:param time_steps: The number of time steps per episode.
:param episodes: The number of episodes.
:param kwargs: Any agent-framework specific key word args.
"""
if not self._training_config.session_type == SessionType.TRAIN:
self._agent_session.evaluate(time_steps, episodes, **kwargs)
self._agent_session.evaluate(**kwargs)
def close(self):
"""Closes the agent."""

View File

@@ -18,23 +18,17 @@ def run(overwrite_existing: bool = True):
:param overwrite_existing: A bool to toggle replacing existing edited
notebooks on or off.
"""
notebooks_package_data_root = pkg_resources.resource_filename(
"primaite", "notebooks/_package_data"
)
notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
for subdir, dirs, files in os.walk(notebooks_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(
fp, notebooks_package_data_root
).split(os.sep)
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
target_fp = NOTEBOOKS_DIR / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()
if overwrite_existing and not copy_file:
copy_file = (not filecmp.cmp(fp, target_fp)) and (
".ipynb_checkpoints" not in str(target_fp)
)
copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))
if copy_file:
shutil.copy2(fp, target_fp)

View File

@@ -17,16 +17,12 @@ def run(overwrite_existing=True):
:param overwrite_existing: A bool to toggle replacing existing edited
config on or off.
"""
configs_package_data_root = pkg_resources.resource_filename(
"primaite", "config/_package_data"
)
configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")
for subdir, dirs, files in os.walk(configs_package_data_root):
for file in files:
fp = os.path.join(subdir, file)
path_split = os.path.relpath(fp, configs_package_data_root).split(
os.sep
)
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
target_fp.parent.mkdir(exist_ok=True, parents=True)
copy_file = not target_fp.is_file()

View File

@@ -76,12 +76,8 @@ class Transaction(object):
row = (
row
+ _turn_action_space_to_array(self.action_space)
+ _turn_obs_space_to_array(
self.obs_space_pre, obs_assets, obs_features
)
+ _turn_obs_space_to_array(
self.obs_space_post, obs_assets, obs_features
)
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features)
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features)
)
return header, row

View File

@@ -51,9 +51,7 @@ class SessionOutputWriter:
self._first_write: bool = True
def _init_csv_writer(self):
self._csv_file = open(
self._csv_file_path, "w", encoding="UTF8", newline=""
)
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
self._csv_writer = csv.writer(self._csv_file)

View File

@@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession):
return self
def __exit__(self, type, value, tb):
del self._agent_session._env.episode_av_reward_writer
del self._agent_session._env.transaction_writer
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
@@ -112,9 +110,7 @@ def temp_primaite_session(request):
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch(
"primaite.agents.agent.get_session_path", get_temp_session_path
) as mck:
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
@@ -130,9 +126,7 @@ def temp_session_path() -> Path:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = (
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
)
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path

View File

@@ -16,9 +16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = (
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
)
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created temp session directory: {session_path}")
return session_path

View File

@@ -95,8 +95,6 @@ def test_rule_hash():
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
hash_value_local = hash(rule)
hash_value_remote = acl.get_dictionary_hash(
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
)
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
assert hash_value_local == hash_value_remote

View File

@@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
1,
)
active_node.set_software_state_if_not_compromised(
SoftwareState.OVERWHELMED
)
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
assert active_node.software_state == expected_state
@@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state):
(HardwareState.ON, FileSystemState.CORRUPT),
],
)
def test_file_system_change_if_not_compromised(
operating_state, expected_state
):
def test_file_system_change_if_not_compromised(operating_state, expected_state):
"""
Test that a node cannot change its file system state.
@@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised(
1,
)
active_node.set_file_system_state_if_not_compromised(
FileSystemState.CORRUPT
)
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
assert active_node.file_system_state_actual == expected_state

View File

@@ -2,11 +2,7 @@
import numpy as np
import pytest
from primaite.environment.observations import (
NodeLinkTable,
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
from tests import TEST_CONFIG_ROOT
@@ -127,9 +123,7 @@ class TestNodeLinkTable:
with temp_primaite_session as session:
env = session.env
# act = np.asarray([0,])
obs, reward, done, info = env.step(
0
) # apply the 'do nothing' action
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs,
@@ -192,17 +186,15 @@ class TestNodeStatuses:
with temp_primaite_session as session:
env = session.env
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
)
print(obs)
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
@pytest.mark.parametrize(
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
]
],

View File

@@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session):
# Check that every csv file present is either the transactions or the av reward file
for file in session.learning_path.iterdir():
if file.suffix == ".csv":
assert (
"all_transactions" in file.name
or "average_reward_per_episode" in file.name
)
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
# Check that every csv file present is either the transactions or the av reward file
for file in session.evaluation_path.iterdir():
if file.suffix == ".csv":
assert (
"all_transactions" in file.name
or "average_reward_per_episode" in file.name
)
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
_LOGGER.debug("Inspecting files in temp session path...")
for dir_path, dir_names, file_names in os.walk(session_path):

View File

@@ -1,13 +1,7 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode
"starting_operating_state, expected_operating_state",
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(
starting_operating_state, expected_operating_state
):
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id="0",
@@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
service_node.add_service(service_attributes)
for x in range(5):

View File

@@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state):
(HardwareState.ON, SoftwareState.OVERWHELMED),
],
)
def test_service_state_change_if_not_comprised(
operating_state, expected_state
):
def test_service_state_change_if_not_comprised(operating_state, expected_state):
"""
Test that a node cannot change the state of a running service.
@@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised(
service = Service("TCP", 80, SoftwareState.GOOD)
service_node.add_service(service)
service_node.set_service_state_if_not_compromised(
"TCP", SoftwareState.OVERWHELMED
)
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
assert service_node.get_service_state("TCP") == expected_state

View File

@@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# Creates an ACL rule
@@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
"temp_primaite_session",
[
[
TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
]
],

View File

@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT
def test_legacy_lay_down_config_yaml_conversion():
"""Tests the conversion of legacy training config files."""
legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
with open(legacy_path, "r") as file:
legacy_dict = yaml.safe_load(file)
@@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion():
with open(new_path, "r") as file:
new_dict = yaml.safe_load(file)
converted_dict = training_config.convert_legacy_training_config_dict(
legacy_dict
)
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
for key, value in new_dict.items():
assert converted_dict[key] == value
@@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():
def test_create_config_values_main_from_file():
"""Tests creating an instance of TrainingConfig from file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
training_config.load(new_path)
def test_create_config_values_main_from_legacy_file():
"""Tests creating an instance of TrainingConfig from legacy file."""
new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
training_config.load(new_path, legacy_file=True)