#917 - Fixed the RLlib integration
- Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files.
This commit is contained in:
@@ -13,7 +13,7 @@ repos:
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [ "--line-length=79" ]
|
||||
args: [ "--line-length=120" ]
|
||||
additional_dependencies:
|
||||
- jupyter
|
||||
- repo: http://github.com/pycqa/isort
|
||||
|
||||
@@ -72,9 +72,9 @@ primaite = "primaite.cli:app"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 79
|
||||
line_length = 120
|
||||
force_sort_within_sections = "False"
|
||||
order_by_type = "False"
|
||||
|
||||
[tool.black]
|
||||
line-length = 79
|
||||
line-length = 120
|
||||
|
||||
@@ -19,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite")
|
||||
def _get_primaite_config():
|
||||
config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml"
|
||||
if not config_path.exists():
|
||||
config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
)
|
||||
)
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
with open(config_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
log_level_map = {
|
||||
@@ -34,9 +30,7 @@ def _get_primaite_config():
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[
|
||||
primaite_config["logging"]["log_level"]
|
||||
]
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
|
||||
@@ -82,14 +76,9 @@ class _LevelFormatter(Formatter):
|
||||
super().__init__()
|
||||
|
||||
if "fmt" in kwargs:
|
||||
raise ValueError(
|
||||
"Format string must be passed to level-surrogate formatters, "
|
||||
"not this one"
|
||||
)
|
||||
raise ValueError("Format string must be passed to level-surrogate formatters, " "not this one")
|
||||
|
||||
self.formats = sorted(
|
||||
(level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()
|
||||
)
|
||||
self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items())
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
"""Overrides ``Formatter.format``."""
|
||||
@@ -110,13 +99,9 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][
|
||||
"WARNING"
|
||||
],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][
|
||||
"CRITICAL"
|
||||
],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ class AccessControlList:
|
||||
|
||||
def __init__(self):
|
||||
"""Init."""
|
||||
self.acl: Dict[
|
||||
str, AccessControlList
|
||||
] = {} # A dictionary of ACL Rules
|
||||
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
|
||||
|
||||
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
|
||||
"""
|
||||
@@ -27,29 +25,16 @@ class AccessControlList:
|
||||
True if match; False otherwise.
|
||||
"""
|
||||
if (
|
||||
(
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == "ANY"
|
||||
and _rule.get_dest_ip() == _dest_ip_address
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == _source_ip_address
|
||||
and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
or (
|
||||
_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY"
|
||||
)
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(
|
||||
self, _source_ip_address, _dest_ip_address, _protocol, _port
|
||||
):
|
||||
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
@@ -63,15 +48,9 @@ class AccessControlList:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule_key, rule_value in self.acl.items():
|
||||
if self.check_address_match(
|
||||
rule_value, _source_ip_address, _dest_ip_address
|
||||
):
|
||||
if (
|
||||
rule_value.get_protocol() == _protocol
|
||||
or rule_value.get_protocol() == "ANY"
|
||||
) and (
|
||||
str(rule_value.get_port()) == str(_port)
|
||||
or rule_value.get_port() == "ANY"
|
||||
if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address):
|
||||
if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and (
|
||||
str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule_value.get_permission() == "DENY":
|
||||
@@ -93,9 +72,7 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
new_rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(new_rule)
|
||||
self.acl[hash_value] = new_rule
|
||||
|
||||
@@ -110,9 +87,7 @@ class AccessControlList:
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
# There will not always be something 'popable' since the agent will be trying random things
|
||||
try:
|
||||
@@ -124,9 +99,7 @@ class AccessControlList:
|
||||
"""Removes all rules."""
|
||||
self.acl.clear()
|
||||
|
||||
def get_dictionary_hash(
|
||||
self, _permission, _source_ip, _dest_ip, _protocol, _port
|
||||
):
|
||||
def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port):
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
@@ -140,8 +113,6 @@ class AccessControlList:
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
rule = ACLRule(
|
||||
_permission, _source_ip, _dest_ip, _protocol, str(_port)
|
||||
)
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Dict, Final, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
@@ -51,16 +51,12 @@ class AgentSessionABC(ABC):
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path
|
||||
)
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(
|
||||
self._lay_down_config_path
|
||||
)
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.output_verbose_level = self._training_config.output_verbose_level
|
||||
|
||||
self._env: Primaite
|
||||
@@ -132,9 +128,7 @@ class AgentSessionABC(ABC):
|
||||
"learning": {"total_episodes": None, "total_time_steps": None},
|
||||
"evaluation": {"total_episodes": None, "total_time_steps": None},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(
|
||||
json_serializable=True
|
||||
),
|
||||
"training_config": self._training_config.to_dict(json_serializable=True),
|
||||
"lay_down_config": self._lay_down_config,
|
||||
},
|
||||
}
|
||||
@@ -161,19 +155,11 @@ class AgentSessionABC(ABC):
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"][
|
||||
"total_episodes"
|
||||
] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"][
|
||||
"total_time_steps"
|
||||
] = self._env.total_step_count # noqa
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -184,12 +170,9 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment "
|
||||
f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(
|
||||
f"The output directory for this session is: {self.session_path}"
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
|
||||
self._write_session_metadata_file()
|
||||
self._can_learn = True
|
||||
self._can_evaluate = False
|
||||
@@ -201,17 +184,11 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_learn:
|
||||
@@ -225,17 +202,11 @@ class AgentSessionABC(ABC):
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
@@ -281,9 +252,7 @@ class AgentSessionABC(ABC):
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = (
|
||||
f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
)
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
pass
|
||||
@@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
@@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
|
||||
@@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs):
|
||||
if (
|
||||
self._training_config.hard_coded_agent_view
|
||||
== HardCodedAgentView.BASIC
|
||||
):
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
@@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
):
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
@@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
|
||||
@@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
dest_node_address = dest_node_id
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[
|
||||
protocol - 1
|
||||
] # -1 as dont have to account for ANY in list of services
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
|
||||
matching_rules = acl.get_relevant_rules(
|
||||
source_node_address, dest_node_address, protocol, port
|
||||
)
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
|
||||
@@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(
|
||||
get_node_of_ip(action_source_ip, self._env.nodes)
|
||||
)
|
||||
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(
|
||||
get_node_of_ip(
|
||||
action_destination_ip, self._env.nodes
|
||||
)
|
||||
)
|
||||
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name)
|
||||
+ 1
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = (
|
||||
@@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(
|
||||
self._env.green_iers, self._env.acl, self._env.nodes
|
||||
)
|
||||
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(
|
||||
ier, self._env.acl, self._env.nodes
|
||||
)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(
|
||||
service_name_to_check
|
||||
)
|
||||
service_id_to_check = self._env.services_list.index(service_name_to_check)
|
||||
|
||||
# Service state of the the source node in the ier
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
@@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len(
|
||||
[i for i in o if i != "NONE"]
|
||||
) # number of nodes (not links)
|
||||
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [
|
||||
n for n, i in enumerate(service_states) if i == "COMPROMISED"
|
||||
]
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = (
|
||||
np.random.choice(comprimised_states) + 1
|
||||
) # +1 as 0 would be any
|
||||
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(
|
||||
list(range(1, number_of_nodes + 1)) + ["ANY"]
|
||||
)
|
||||
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip)
|
||||
if action_destination_ip != "ANY"
|
||||
else action_destination_ip
|
||||
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_action_node_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
@@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
if os_state == "OFF":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OPERATING"
|
||||
property_action = (
|
||||
"ON" # Why reset it when we can just turn it on
|
||||
)
|
||||
action_service_index = (
|
||||
0 # does nothing isn't relevant for operating state
|
||||
)
|
||||
property_action = "ON" # Why reset it when we can just turn it on
|
||||
action_service_index = 0 # does nothing isn't relevant for operating state
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
|
||||
@@ -3,9 +3,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.a2c import A2CConfig
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
@@ -14,11 +13,7 @@ from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.common.enums import (
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
)
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = (
|
||||
f"Expected RLLIB agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_config_class = A2CConfig
|
||||
else:
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier.value}"
|
||||
)
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_config: PPOConfig
|
||||
@@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC):
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
metadata_dict["total_episodes"] = self._current_result[
|
||||
"episodes_total"
|
||||
]
|
||||
metadata_dict["total_time_steps"] = self._current_result[
|
||||
"timesteps_total"
|
||||
]
|
||||
metadata_dict["total_episodes"] = self._current_result["episodes_total"]
|
||||
metadata_dict["total_time_steps"] = self._current_result["timesteps_total"]
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
@@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC):
|
||||
),
|
||||
)
|
||||
|
||||
self._agent_config.training(
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._agent_config.training(train_batch_size=self._training_config.num_steps)
|
||||
self._agent_config.framework(framework="tf")
|
||||
|
||||
self._agent_config.rollouts(
|
||||
@@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC):
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps,
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(
|
||||
logger_creator=_custom_log_creator(self.session_path)
|
||||
)
|
||||
self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
self._agent_config.train_batch_size = time_steps
|
||||
self._agent_config.horizon = time_steps
|
||||
|
||||
if not episodes:
|
||||
time_steps = self._training_config.num_steps
|
||||
episodes = self._training_config.num_episodes
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
if (
|
||||
self._training_config.deep_learning_framework
|
||||
!= DeepLearningFramework.TORCH
|
||||
):
|
||||
policy = self._agent.get_policy()
|
||||
tf.compat.v1.summary.FileWriter(
|
||||
self.session_path / "ray_results", policy.get_session().graph
|
||||
)
|
||||
super().learn()
|
||||
self._agent.stop()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
@@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC):
|
||||
def __init__(self, training_config_path, lay_down_config_path):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = (
|
||||
f"Expected SB3 agent_framework, "
|
||||
f"got {self._training_config.agent_framework}"
|
||||
)
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
@@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC):
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = (
|
||||
"Expected PPO or A2C agent_identifier, "
|
||||
f"got {self._training_config.agent_identifier}"
|
||||
)
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env,
|
||||
verbose=self.output_verbose_level,
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
)
|
||||
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (episode_count % checkpoint_n == 0) or (
|
||||
episode_count == self._training_config.num_episodes
|
||||
):
|
||||
checkpoint_path = (
|
||||
self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
)
|
||||
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
@@ -85,57 +75,36 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(
|
||||
f"Beginning learning for {episodes} episodes @"
|
||||
f" {time_steps} time steps..."
|
||||
)
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
|
||||
self.close()
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of steps per episode. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param episodes: The number of episodes. Optional. If not
|
||||
passed, the value from the training config will be used.
|
||||
:param deterministic: Whether the evaluation is deterministic.
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
@@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC):
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for "
|
||||
f"{episodes} episodes @ {time_steps} time steps..."
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(
|
||||
obs, deterministic=deterministic
|
||||
)
|
||||
action, _states = self._agent.predict(obs, deterministic=deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
transform_action_acl_enum,
|
||||
transform_action_node_enum,
|
||||
)
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
|
||||
@@ -24,9 +24,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -117,11 +115,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
@@ -173,9 +167,7 @@ def transform_change_obs_readable(obs):
|
||||
|
||||
for service in range(3, obs.shape[1]):
|
||||
# Links bit/s don't have a service state
|
||||
service_states = [
|
||||
SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
|
||||
]
|
||||
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
|
||||
new_obs.append(service_states)
|
||||
|
||||
return new_obs
|
||||
@@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(
|
||||
obs1, obs2, num_nodes=10, num_links=10, num_services=1
|
||||
):
|
||||
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
|
||||
"""
|
||||
Return string describing change between two observations.
|
||||
|
||||
@@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link):
|
||||
TODO: Typehint params and return.
|
||||
"""
|
||||
# Indexes where a change has occured, not including 0th index
|
||||
index_changed = [
|
||||
i for i in range(1, len(obs_change)) if obs_change[i] != -1
|
||||
]
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [
|
||||
NodePOLType(i).name
|
||||
if i < 3
|
||||
else NodePOLType(3).name + " " + str(i - 3)
|
||||
for i in index_changed
|
||||
]
|
||||
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
|
||||
# Account for hardware states, software sattes and links
|
||||
states = [
|
||||
LinkStatus(obs_change[i]).name
|
||||
@@ -367,9 +350,7 @@ def transform_action_node_readable(action):
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (
|
||||
action_node_property == "OS" or action_node_property == "SERVICE"
|
||||
) and action[2] <= 1:
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
@@ -397,9 +378,7 @@ def node_action_description(action):
|
||||
if property_action == "NONE":
|
||||
return ""
|
||||
if node_property == "OPERATING" or node_property == "OS":
|
||||
description = (
|
||||
f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
)
|
||||
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
|
||||
elif node_property == "SERVICE":
|
||||
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
|
||||
else:
|
||||
@@ -522,11 +501,7 @@ def is_valid_acl_action(action):
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if (
|
||||
action_source_id == action_destination_id
|
||||
and action_source_id != "ANY"
|
||||
and action_destination_id != "ANY"
|
||||
):
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
|
||||
@@ -56,9 +56,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum(
|
||||
"LogLevel", {k: k for k in logging._levelToName.values()}
|
||||
) # noqa
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -124,21 +122,12 @@ def setup(overwrite_existing: bool = True):
|
||||
app_dirs = PlatformDirs(appname="primaite")
|
||||
app_dirs.user_config_path.mkdir(exist_ok=True, parents=True)
|
||||
user_config_path = app_dirs.user_config_path / "primaite_config.yaml"
|
||||
pkg_config_path = Path(
|
||||
pkg_resources.resource_filename(
|
||||
"primaite", "setup/_package_data/primaite_config.yaml"
|
||||
)
|
||||
)
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
|
||||
shutil.copy2(pkg_config_path, user_config_path)
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import (
|
||||
old_installation_clean_up,
|
||||
reset_demo_notebooks,
|
||||
reset_example_configs,
|
||||
setup_app_dirs,
|
||||
)
|
||||
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -188,9 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None):
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(
|
||||
template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None
|
||||
):
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None):
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
@@ -208,14 +195,10 @@ def plotly_template(
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if template:
|
||||
primaite_config["session"]["outputs"]["plots"][
|
||||
"template"
|
||||
] = template.value
|
||||
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
|
||||
with open(user_config_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE plotly template: {template.value}")
|
||||
else:
|
||||
template = primaite_config["session"]["outputs"]["plots"][
|
||||
"template"
|
||||
]
|
||||
template = primaite_config["session"]["outputs"]["plots"]["template"]
|
||||
print(f"PrimAITE plotly template: {template}")
|
||||
|
||||
@@ -8,14 +8,10 @@ from primaite import getLogger, USERS_CONFIG_DIR
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = (
|
||||
USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
)
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config_dict(
|
||||
legacy_config_dict: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy lay down config dict to the new format.
|
||||
|
||||
|
||||
@@ -20,9 +20,7 @@ from primaite.common.enums import (
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = (
|
||||
USERS_CONFIG_DIR / "example_config" / "training"
|
||||
)
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
@@ -68,9 +66,7 @@ class TrainingConfig:
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
|
||||
observation_space: dict = field(
|
||||
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
)
|
||||
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
|
||||
"The observation space config dict"
|
||||
|
||||
time_delay: int = 10
|
||||
@@ -180,9 +176,7 @@ class TrainingConfig:
|
||||
"The time taken to scan the file system"
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, config_dict: Dict[str, Union[str, int, bool]]
|
||||
) -> TrainingConfig:
|
||||
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
@@ -238,9 +232,7 @@ class TrainingConfig:
|
||||
return tc
|
||||
|
||||
|
||||
def load(
|
||||
file_path: Union[str, Path], legacy_file: bool = False
|
||||
) -> TrainingConfig:
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
@@ -271,10 +263,7 @@ def load(
|
||||
try:
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
msg = (
|
||||
f"Error when creating an instance of {TrainingConfig} "
|
||||
f"from the training config file {file_path}"
|
||||
)
|
||||
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
|
||||
_LOGGER.critical(msg, exc_info=True)
|
||||
raise e
|
||||
msg = f"Cannot load the training config as it does not exist: {file_path}"
|
||||
@@ -314,9 +303,7 @@ def convert_legacy_training_config_dict(
|
||||
"output_verbose_level": output_verbose_level.name,
|
||||
}
|
||||
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
|
||||
legacy_config_dict["sessionType"] = session_type_map[
|
||||
legacy_config_dict["sessionType"]
|
||||
]
|
||||
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
|
||||
@@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(
|
||||
observation_shape, dtype=self._DATA_TYPE
|
||||
)
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
@@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][
|
||||
2
|
||||
] = node.software_state.value
|
||||
self.current_observation[item_index][
|
||||
3
|
||||
] = node.file_system_state_observed.value
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
@@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.env.services_list:
|
||||
if node.has_service(service):
|
||||
self.current_observation[item_index][
|
||||
service_index
|
||||
] = node.get_service_state(service).value
|
||||
self.current_observation[item_index][service_index] = node.get_service_state(service).value
|
||||
else:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
@@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.current_observation[item_index][
|
||||
protocol_index + 4
|
||||
] = protocol.get_load()
|
||||
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
@@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(
|
||||
service
|
||||
).value
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
@@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
self._entries_per_link = self.env.num_services
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = (
|
||||
[self._quantisation_levels]
|
||||
* self.env.num_links
|
||||
* self._entries_per_link
|
||||
)
|
||||
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
@@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [
|
||||
protocol.get_load() for protocol in link.protocol_list
|
||||
]
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
@@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
elif load >= bandwidth:
|
||||
traffic_level = self._quantisation_levels - 1
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (
|
||||
1 / (self._quantisation_levels - 2)
|
||||
) + 1
|
||||
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
|
||||
@@ -12,13 +12,11 @@ from matplotlib import pyplot as plt
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.agents.utils import (
|
||||
is_valid_acl_action_extra,
|
||||
is_valid_node_action,
|
||||
)
|
||||
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
@@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node import Node
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.green_pol import apply_iers, apply_node_pol
|
||||
from primaite.pol.ier import IER
|
||||
from primaite.pol.red_agent_pol import (
|
||||
apply_red_agent_iers,
|
||||
apply_red_agent_node_pol,
|
||||
)
|
||||
from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol
|
||||
from primaite.transactions.transaction import Transaction
|
||||
from primaite.utils.session_output_writer import SessionOutputWriter
|
||||
|
||||
@@ -85,9 +78,7 @@ class Primaite(Env):
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(
|
||||
training_config_path
|
||||
)
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
@@ -238,25 +229,22 @@ class Primaite(Env):
|
||||
self.action_dict = self.create_node_and_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
else:
|
||||
_LOGGER.error(
|
||||
f"Invalid action type selected: {self.training_config.action_type}"
|
||||
)
|
||||
_LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}")
|
||||
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=True
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=True
|
||||
)
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True)
|
||||
|
||||
@property
|
||||
def actual_episode_count(self) -> int:
|
||||
"""Shifts the episode_count by -1 for RLlib."""
|
||||
if self.training_config.agent_framework is AgentFramework.RLLIB:
|
||||
return self.episode_count - 1
|
||||
return self.episode_count
|
||||
|
||||
def set_as_eval(self):
|
||||
"""Set the writers to write to eval directories."""
|
||||
self.episode_av_reward_writer = SessionOutputWriter(
|
||||
self, transaction_writer=False, learning_session=False
|
||||
)
|
||||
self.transaction_writer = SessionOutputWriter(
|
||||
self, transaction_writer=True, learning_session=False
|
||||
)
|
||||
self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False)
|
||||
self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False)
|
||||
self.episode_count = 0
|
||||
self.step_count = 0
|
||||
self.total_step_count = 0
|
||||
@@ -268,8 +256,8 @@ class Primaite(Env):
|
||||
Returns:
|
||||
Environment observation space (reset)
|
||||
"""
|
||||
if self.episode_count > 0:
|
||||
csv_data = self.episode_count, self.average_reward
|
||||
if self.actual_episode_count > 0:
|
||||
csv_data = self.actual_episode_count, self.average_reward
|
||||
self.episode_av_reward_writer.write(csv_data)
|
||||
|
||||
self.episode_count += 1
|
||||
@@ -291,6 +279,7 @@ class Primaite(Env):
|
||||
|
||||
# Update observations space and return
|
||||
self.update_environent_obs()
|
||||
|
||||
return self.env_obs
|
||||
|
||||
def step(self, action):
|
||||
@@ -319,9 +308,7 @@ class Primaite(Env):
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
self.agent_identifier, self.episode_count, self.step_count
|
||||
)
|
||||
transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count)
|
||||
# Load the initial observation space into the transaction
|
||||
transaction.obs_space_pre = copy.deepcopy(self.env_obs)
|
||||
# Load the action space into the transaction
|
||||
@@ -350,9 +337,7 @@ class Primaite(Env):
|
||||
self.nodes_post_pol = copy.deepcopy(self.nodes)
|
||||
self.links_post_pol = copy.deepcopy(self.links)
|
||||
# Reference
|
||||
apply_node_pol(
|
||||
self.nodes_reference, self.node_pol, self.step_count
|
||||
) # Node PoL
|
||||
apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL
|
||||
apply_iers(
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
@@ -371,9 +356,7 @@ class Primaite(Env):
|
||||
self.acl,
|
||||
self.step_count,
|
||||
)
|
||||
apply_red_agent_node_pol(
|
||||
self.nodes, self.red_iers, self.red_node_pol, self.step_count
|
||||
)
|
||||
apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count)
|
||||
# Take snapshots of nodes and links
|
||||
self.nodes_post_red = copy.deepcopy(self.nodes)
|
||||
self.links_post_red = copy.deepcopy(self.links)
|
||||
@@ -389,11 +372,7 @@ class Primaite(Env):
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
_LOGGER.debug(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Step {self.step_count}, "
|
||||
f"Reward: {reward}"
|
||||
)
|
||||
_LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -401,10 +380,7 @@ class Primaite(Env):
|
||||
# For evaluation, need to trigger the done value = True when
|
||||
# step count is reached in order to prevent neverending episode
|
||||
done = True
|
||||
_LOGGER.info(
|
||||
f"Episode: {self.episode_count}, "
|
||||
f"Average Reward: {self.average_reward}"
|
||||
)
|
||||
_LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}")
|
||||
# Load the reward into the transaction
|
||||
transaction.reward = reward
|
||||
|
||||
@@ -417,11 +393,21 @@ class Primaite(Env):
|
||||
transaction.obs_space_post = copy.deepcopy(self.env_obs)
|
||||
|
||||
# Write transaction to file
|
||||
if self.actual_episode_count > 0:
|
||||
self.transaction_writer.write(transaction)
|
||||
|
||||
# Return
|
||||
return self.env_obs, reward, done, self.step_info
|
||||
|
||||
def close(self):
|
||||
"""Override parent close and close writers."""
|
||||
# Close files if last episode/step
|
||||
# if self.can_finish:
|
||||
super().close()
|
||||
|
||||
self.transaction_writer.close()
|
||||
self.episode_av_reward_writer.close()
|
||||
|
||||
def init_acl(self):
|
||||
"""Initialise the Access Control List."""
|
||||
self.acl.remove_all_rules()
|
||||
@@ -431,12 +417,7 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
_LOGGER.debug("Link ID: " + link_value.get_id())
|
||||
for protocol in link_value.protocol_list:
|
||||
_LOGGER.debug(
|
||||
" Protocol: "
|
||||
+ protocol.get_name().name
|
||||
+ ", Load: "
|
||||
+ str(protocol.get_load())
|
||||
)
|
||||
_LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load()))
|
||||
|
||||
def interpret_action_and_apply(self, _action):
|
||||
"""
|
||||
@@ -450,13 +431,9 @@ class Primaite(Env):
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 6
|
||||
): # ACL actions in multidiscrete form have len 6
|
||||
elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6
|
||||
self.apply_actions_to_acl(_action)
|
||||
elif (
|
||||
len(self.action_dict[_action]) == 4
|
||||
): # Node actions in multdiscrete (array) from have len 4
|
||||
elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4
|
||||
self.apply_actions_to_nodes(_action)
|
||||
else:
|
||||
_LOGGER.error("Invalid action type found")
|
||||
@@ -541,10 +518,7 @@ class Primaite(Env):
|
||||
elif property_action == 2:
|
||||
# Repair
|
||||
# You cannot repair a destroyed file system - it needs restoring
|
||||
if (
|
||||
node.file_system_state_actual
|
||||
!= FileSystemState.DESTROYED
|
||||
):
|
||||
if node.file_system_state_actual != FileSystemState.DESTROYED:
|
||||
node.set_file_system_state(FileSystemState.REPAIRING)
|
||||
elif property_action == 3:
|
||||
# Restore
|
||||
@@ -587,9 +561,7 @@ class Primaite(Env):
|
||||
acl_rule_source = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_source_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
acl_rule_source = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -598,9 +570,7 @@ class Primaite(Env):
|
||||
acl_rule_destination = "ANY"
|
||||
else:
|
||||
node = list(self.nodes.values())[action_destination_ip - 1]
|
||||
if isinstance(node, ServiceNode) or isinstance(
|
||||
node, ActiveNode
|
||||
):
|
||||
if isinstance(node, ServiceNode) or isinstance(node, ActiveNode):
|
||||
acl_rule_destination = node.ip_address
|
||||
else:
|
||||
return
|
||||
@@ -685,9 +655,7 @@ class Primaite(Env):
|
||||
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||
"""
|
||||
self.obs_handler = ObservationsHandler.from_config(
|
||||
self, self.obs_config
|
||||
)
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
@@ -794,9 +762,7 @@ class Primaite(Env):
|
||||
service_protocol = service["name"]
|
||||
service_port = service["port"]
|
||||
service_state = SoftwareState[service["state"]]
|
||||
node.add_service(
|
||||
Service(service_protocol, service_port, service_state)
|
||||
)
|
||||
node.add_service(Service(service_protocol, service_port, service_state))
|
||||
else:
|
||||
# Bad formatting
|
||||
pass
|
||||
@@ -849,9 +815,7 @@ class Primaite(Env):
|
||||
dest_node_ref: Node = self.nodes_reference[link_destination]
|
||||
|
||||
# Add link to network (reference)
|
||||
self.network_reference.add_edge(
|
||||
source_node_ref, dest_node_ref, id=link_name
|
||||
)
|
||||
self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name)
|
||||
|
||||
# Add link to link dictionary (reference)
|
||||
self.links_reference[link_name] = Link(
|
||||
@@ -1126,9 +1090,7 @@ class Primaite(Env):
|
||||
# All nodes have these parameters
|
||||
node_id = item["node_id"]
|
||||
node_class = item["node_class"]
|
||||
node_hardware_state: HardwareState = HardwareState[
|
||||
item["hardware_state"]
|
||||
]
|
||||
node_hardware_state: HardwareState = HardwareState[item["hardware_state"]]
|
||||
|
||||
node: NodeUnion = self.nodes[node_id]
|
||||
node_ref = self.nodes_reference[node_id]
|
||||
@@ -1249,11 +1211,7 @@ class Primaite(Env):
|
||||
|
||||
# Change node keys to not overlap with acl keys
|
||||
# Only 1 nothing action (key 0) is required, remove the other
|
||||
new_node_action_dict = {
|
||||
k + len(acl_action_dict) - 1: v
|
||||
for k, v in node_action_dict.items()
|
||||
if k != 0
|
||||
}
|
||||
new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0}
|
||||
|
||||
# Combine the Node dict and ACL dict
|
||||
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
|
||||
|
||||
@@ -41,29 +41,19 @@ def calculate_reward_function(
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
# Hardware State
|
||||
reward_value += score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Software State
|
||||
if isinstance(final_node, ActiveNode) or isinstance(
|
||||
final_node, ServiceNode
|
||||
):
|
||||
reward_value += score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Service State
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
)
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
@@ -82,10 +72,7 @@ def calculate_reward_function(
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
)
|
||||
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
reward_value += ier_reward
|
||||
@@ -107,9 +94,7 @@ def calculate_reward_function(
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
@@ -158,9 +143,7 @@ def score_node_operating_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_os_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
@@ -211,9 +194,7 @@ def score_node_os_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_service_state(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
@@ -285,9 +266,7 @@ def score_node_service_state(
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(
|
||||
final_node, initial_node, reference_node, config_values
|
||||
):
|
||||
def score_node_file_system(final_node, initial_node, reference_node, config_values):
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
|
||||
@@ -8,9 +8,7 @@ from primaite.common.protocol import Protocol
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(
|
||||
self, _id, _bandwidth, _source_node_name, _dest_node_name, _services
|
||||
):
|
||||
def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services):
|
||||
"""
|
||||
Init.
|
||||
|
||||
|
||||
@@ -32,11 +32,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--ldc")
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error(
|
||||
"Please provide a training config file using the --tc " "argument"
|
||||
)
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error(
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
@@ -3,13 +3,7 @@
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
@@ -44,9 +38,7 @@ class ActiveNode(Node):
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(
|
||||
node_id, name, node_type, priority, hardware_state, config_values
|
||||
)
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
self.ip_address: str = ip_address
|
||||
# Related to Software
|
||||
self._software_state: SoftwareState = software_state
|
||||
@@ -87,9 +79,7 @@ class ActiveNode(Node):
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(
|
||||
self, software_state: SoftwareState
|
||||
):
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState):
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
@@ -100,9 +90,7 @@ class ActiveNode(Node):
|
||||
if self._software_state != SoftwareState.COMPROMISED:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = (
|
||||
self.config_values.os_patching_duration
|
||||
)
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be changed."
|
||||
@@ -129,14 +117,10 @@ class ActiveNode(Node):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
@@ -149,9 +133,7 @@ class ActiveNode(Node):
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(
|
||||
self, file_system_state: FileSystemState
|
||||
):
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState):
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
@@ -168,14 +150,10 @@ class ActiveNode(Node):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_repairing_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = (
|
||||
self.config_values.file_system_restoring_limit
|
||||
)
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
@@ -191,9 +169,7 @@ class ActiveNode(Node):
|
||||
def start_file_system_scan(self):
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = (
|
||||
self.config_values.file_system_scanning_limit
|
||||
)
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self):
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
@@ -212,10 +188,7 @@ class ActiveNode(Node):
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
|
||||
# Scanning updates
|
||||
if (
|
||||
self.file_system_scanning == True
|
||||
and self.file_system_scanning_count < 0
|
||||
):
|
||||
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
|
||||
self.file_system_state_observed = self.file_system_state_actual
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
@@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object):
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type = _node_pol_type
|
||||
self.service_name = (
|
||||
_service_name # Not used when not a service instruction
|
||||
)
|
||||
self.service_name = _service_name # Not used when not a service instruction
|
||||
self.state = _state
|
||||
|
||||
def get_start_step(self):
|
||||
|
||||
@@ -42,9 +42,7 @@ class NodeStateInstructionRed(object):
|
||||
self.target_node_id = _target_node_id
|
||||
self.initiator = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name = (
|
||||
pol_protocol # Not used when not a service instruction
|
||||
)
|
||||
self.service_name = pol_protocol # Not used when not a service instruction
|
||||
self.state = _pol_state
|
||||
self.source_node_id = _pol_source_node_id
|
||||
self.source_node_service = _pol_source_node_service
|
||||
|
||||
@@ -28,9 +28,7 @@ class PassiveNode(Node):
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
# Pass through to Super for now
|
||||
super().__init__(
|
||||
node_id, name, node_type, priority, hardware_state, config_values
|
||||
)
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
|
||||
@@ -3,13 +3,7 @@
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -110,9 +104,7 @@ class ServiceNode(ActiveNode):
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(
|
||||
self, protocol_name: str, software_state: SoftwareState
|
||||
):
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode):
|
||||
) or software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
@@ -143,9 +133,7 @@ class ServiceNode(ActiveNode):
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(
|
||||
self, protocol_name: str, software_state: SoftwareState
|
||||
):
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState):
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
@@ -161,9 +149,7 @@ class ServiceNode(ActiveNode):
|
||||
if service_value.software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = (
|
||||
self.config_values.service_patching_duration
|
||||
)
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
|
||||
@@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import (
|
||||
NodeStateInstructionGreen,
|
||||
)
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
@@ -93,9 +86,7 @@ def apply_iers(
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(
|
||||
protocol
|
||||
) and not source_node.service_is_overwhelmed(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -110,10 +101,7 @@ def apply_iers(
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
dest_node.hardware_state == HardwareState.ON
|
||||
and dest_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
@@ -123,14 +111,9 @@ def apply_iers(
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if (
|
||||
dest_node.hardware_state == HardwareState.ON
|
||||
and dest_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(
|
||||
protocol
|
||||
) and not dest_node.service_is_overwhelmed(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
@@ -143,9 +126,7 @@ def apply_iers(
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.ip_address, dest_node.ip_address, protocol, port
|
||||
)
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
@@ -176,10 +157,7 @@ def apply_iers(
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if (
|
||||
node.hardware_state != HardwareState.ON
|
||||
or node.software_state == SoftwareState.PATCHING
|
||||
):
|
||||
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
@@ -191,15 +169,11 @@ def apply_iers(
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -226,9 +200,7 @@ def apply_iers(
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
@@ -243,9 +215,7 @@ def apply_iers(
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[
|
||||
any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]
|
||||
],
|
||||
node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]],
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
@@ -277,22 +247,16 @@ def apply_node_pol(
|
||||
elif node_pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_software_state_if_not_compromised(state)
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ServiceNode):
|
||||
node.set_service_state_if_not_compromised(
|
||||
service_name, state
|
||||
)
|
||||
node.set_service_state_if_not_compromised(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(node, ActiveNode) or isinstance(
|
||||
node, ServiceNode
|
||||
):
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
|
||||
@@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
NodePOLInitiator,
|
||||
NodePOLType,
|
||||
NodeType,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
@@ -83,10 +77,7 @@ def apply_red_agent_iers(
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if (
|
||||
source_node.get_service_state(protocol)
|
||||
== SoftwareState.COMPROMISED
|
||||
):
|
||||
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
@@ -124,9 +115,7 @@ def apply_red_agent_iers(
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(
|
||||
source_node.ip_address, dest_node.ip_address, protocol, port
|
||||
)
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
@@ -170,15 +159,11 @@ def apply_red_agent_iers(
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count], path_node_list[count + 1]
|
||||
)
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (
|
||||
link.get_current_load() + load
|
||||
) > link.get_bandwidth():
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
@@ -203,23 +188,16 @@ def apply_red_agent_iers(
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Red IER was allowed to run in step "
|
||||
+ str(step)
|
||||
)
|
||||
print("Red IER was allowed to run in step " + str(step))
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Path not valid - one or more nodes not operational"
|
||||
)
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Red IER was NOT allowed to run in step " + str(step)
|
||||
)
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
@@ -258,9 +236,7 @@ def apply_red_agent_node_pol(
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = (
|
||||
node_instruction.get_source_node_service_state()
|
||||
)
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
|
||||
passed_checks = False
|
||||
|
||||
@@ -274,9 +250,7 @@ def apply_red_agent_node_pol(
|
||||
passed_checks = True
|
||||
elif initiator == NodePOLInitiator.IER:
|
||||
# Need to check there is a red IER incoming
|
||||
passed_checks = is_red_ier_incoming(
|
||||
target_node, iers, pol_type
|
||||
)
|
||||
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
|
||||
elif initiator == NodePOLInitiator.SERVICE:
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
@@ -304,9 +278,7 @@ def apply_red_agent_node_pol(
|
||||
target_node.hardware_state = state
|
||||
elif pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.software_state = state
|
||||
elif pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
@@ -314,15 +286,11 @@ def apply_red_agent_node_pol(
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(
|
||||
target_node, ServiceNode
|
||||
):
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"Node Red Agent PoL not allowed - did not pass checks"
|
||||
)
|
||||
print("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
@@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type):
|
||||
node_id = node.node_id
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if (
|
||||
ier_value.get_is_running()
|
||||
and ier_value.get_dest_node_id() == node_id
|
||||
):
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if (
|
||||
node_pol_type == NodePOLType.OPERATING
|
||||
or node_pol_type == NodePOLType.OS
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Optional, Union
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
@@ -9,18 +9,8 @@ from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import (
|
||||
DoNothingACLAgent,
|
||||
DoNothingNodeAgent,
|
||||
DummyAgent,
|
||||
RandomAgent,
|
||||
)
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
SessionType,
|
||||
)
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
@@ -49,16 +39,12 @@ class PrimaiteSession:
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path
|
||||
)
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(
|
||||
self._lay_down_config_path
|
||||
)
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = None # noqa
|
||||
@@ -69,28 +55,16 @@ class PrimaiteSession:
|
||||
def setup(self):
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}"
|
||||
)
|
||||
if (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.HARDCODED
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.HARDCODED}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -100,24 +74,14 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.DO_NOTHING
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DO_NOTHINGD}"
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -127,49 +91,26 @@ class PrimaiteSession:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif (
|
||||
self._training_config.agent_identifier
|
||||
== AgentIdentifier.RANDOM
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.RANDOM}"
|
||||
)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
elif (
|
||||
self._training_config.agent_identifier == AgentIdentifier.DUMMY
|
||||
):
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Identifier ="
|
||||
f" {AgentIdentifier.DUMMY}"
|
||||
)
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(
|
||||
f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}"
|
||||
)
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(
|
||||
self._training_config_path, self._lay_down_config_path
|
||||
)
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
@@ -182,35 +123,27 @@ class PrimaiteSession:
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(time_steps, episodes, **kwargs)
|
||||
self._agent_session.learn(**kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param time_steps: The number of time steps per episode.
|
||||
:param episodes: The number of episodes.
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(time_steps, episodes, **kwargs)
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Closes the agent."""
|
||||
|
||||
@@ -18,23 +18,17 @@ def run(overwrite_existing: bool = True):
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
notebooks on or off.
|
||||
"""
|
||||
notebooks_package_data_root = pkg_resources.resource_filename(
|
||||
"primaite", "notebooks/_package_data"
|
||||
)
|
||||
notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data")
|
||||
for subdir, dirs, files in os.walk(notebooks_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(
|
||||
fp, notebooks_package_data_root
|
||||
).split(os.sep)
|
||||
path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep)
|
||||
target_fp = NOTEBOOKS_DIR / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
if overwrite_existing and not copy_file:
|
||||
copy_file = (not filecmp.cmp(fp, target_fp)) and (
|
||||
".ipynb_checkpoints" not in str(target_fp)
|
||||
)
|
||||
copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp))
|
||||
|
||||
if copy_file:
|
||||
shutil.copy2(fp, target_fp)
|
||||
|
||||
@@ -17,16 +17,12 @@ def run(overwrite_existing=True):
|
||||
:param overwrite_existing: A bool to toggle replacing existing edited
|
||||
config on or off.
|
||||
"""
|
||||
configs_package_data_root = pkg_resources.resource_filename(
|
||||
"primaite", "config/_package_data"
|
||||
)
|
||||
configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data")
|
||||
|
||||
for subdir, dirs, files in os.walk(configs_package_data_root):
|
||||
for file in files:
|
||||
fp = os.path.join(subdir, file)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(
|
||||
os.sep
|
||||
)
|
||||
path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep)
|
||||
target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split)
|
||||
target_fp.parent.mkdir(exist_ok=True, parents=True)
|
||||
copy_file = not target_fp.is_file()
|
||||
|
||||
@@ -76,12 +76,8 @@ class Transaction(object):
|
||||
row = (
|
||||
row
|
||||
+ _turn_action_space_to_array(self.action_space)
|
||||
+ _turn_obs_space_to_array(
|
||||
self.obs_space_pre, obs_assets, obs_features
|
||||
)
|
||||
+ _turn_obs_space_to_array(
|
||||
self.obs_space_post, obs_assets, obs_features
|
||||
)
|
||||
+ _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features)
|
||||
+ _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features)
|
||||
)
|
||||
return header, row
|
||||
|
||||
|
||||
@@ -51,9 +51,7 @@ class SessionOutputWriter:
|
||||
self._first_write: bool = True
|
||||
|
||||
def _init_csv_writer(self):
|
||||
self._csv_file = open(
|
||||
self._csv_file_path, "w", encoding="UTF8", newline=""
|
||||
)
|
||||
self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="")
|
||||
|
||||
self._csv_writer = csv.writer(self._csv_file)
|
||||
|
||||
|
||||
@@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
del self._agent_session._env.episode_av_reward_writer
|
||||
del self._agent_session._env.transaction_writer
|
||||
shutil.rmtree(self.session_path)
|
||||
shutil.rmtree(self.session_path.parent)
|
||||
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
@@ -112,9 +110,7 @@ def temp_primaite_session(request):
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch(
|
||||
"primaite.agents.agent.get_session_path", get_temp_session_path
|
||||
) as mck:
|
||||
with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
@@ -130,9 +126,7 @@ def temp_session_path() -> Path:
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
@@ -16,9 +16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = (
|
||||
Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
)
|
||||
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
_LOGGER.debug(f"Created temp session directory: {session_path}")
|
||||
return session_path
|
||||
|
||||
@@ -95,8 +95,6 @@ def test_rule_hash():
|
||||
rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
hash_value_local = hash(rule)
|
||||
|
||||
hash_value_remote = acl.get_dictionary_hash(
|
||||
"DENY", "192.168.1.1", "192.168.1.2", "TCP", "80"
|
||||
)
|
||||
hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80")
|
||||
|
||||
assert hash_value_local == hash_value_remote
|
||||
|
||||
@@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state):
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_software_state_if_not_compromised(
|
||||
SoftwareState.OVERWHELMED
|
||||
)
|
||||
active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED)
|
||||
|
||||
assert active_node.software_state == expected_state
|
||||
|
||||
@@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state):
|
||||
(HardwareState.ON, FileSystemState.CORRUPT),
|
||||
],
|
||||
)
|
||||
def test_file_system_change_if_not_compromised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
def test_file_system_change_if_not_compromised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change its file system state.
|
||||
|
||||
@@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised(
|
||||
1,
|
||||
)
|
||||
|
||||
active_node.set_file_system_state_if_not_compromised(
|
||||
FileSystemState.CORRUPT
|
||||
)
|
||||
active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT)
|
||||
|
||||
assert active_node.file_system_state_actual == expected_state
|
||||
|
||||
@@ -2,11 +2,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from primaite.environment.observations import (
|
||||
NodeLinkTable,
|
||||
NodeStatuses,
|
||||
ObservationsHandler,
|
||||
)
|
||||
from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
|
||||
@@ -127,9 +123,7 @@ class TestNodeLinkTable:
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
# act = np.asarray([0,])
|
||||
obs, reward, done, info = env.step(
|
||||
0
|
||||
) # apply the 'do nothing' action
|
||||
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||
|
||||
assert np.array_equal(
|
||||
obs,
|
||||
@@ -192,17 +186,15 @@ class TestNodeStatuses:
|
||||
with temp_primaite_session as session:
|
||||
env = session.env
|
||||
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||
assert np.array_equal(
|
||||
obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]
|
||||
)
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||
TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
]
|
||||
],
|
||||
|
||||
@@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session):
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.learning_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
# Check that both the transactions and av reward csv files exist
|
||||
for file in session.evaluation_path.iterdir():
|
||||
if file.suffix == ".csv":
|
||||
assert (
|
||||
"all_transactions" in file.name
|
||||
or "average_reward_per_episode" in file.name
|
||||
)
|
||||
assert "all_transactions" in file.name or "average_reward_per_episode" in file.name
|
||||
|
||||
_LOGGER.debug("Inspecting files in temp session path...")
|
||||
for dir_path, dir_names, file_names in os.walk(session_path):
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(
|
||||
starting_operating_state, expected_operating_state
|
||||
):
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id="0",
|
||||
@@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(
|
||||
name="node", port="80", software_state=SoftwareState.COMPROMISED
|
||||
)
|
||||
service_attributes = Service(name="node", port="80", software_state=SoftwareState.COMPROMISED)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
|
||||
@@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state):
|
||||
(HardwareState.ON, SoftwareState.OVERWHELMED),
|
||||
],
|
||||
)
|
||||
def test_service_state_change_if_not_comprised(
|
||||
operating_state, expected_state
|
||||
):
|
||||
def test_service_state_change_if_not_comprised(operating_state, expected_state):
|
||||
"""
|
||||
Test that a node cannot change the state of a running service.
|
||||
|
||||
@@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised(
|
||||
service = Service("TCP", 80, SoftwareState.GOOD)
|
||||
service_node.add_service(service)
|
||||
|
||||
service_node.set_service_state_if_not_compromised(
|
||||
"TCP", SoftwareState.OVERWHELMED
|
||||
)
|
||||
service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED)
|
||||
|
||||
assert service_node.get_service_state("TCP") == expected_state
|
||||
|
||||
@@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite):
|
||||
# TEMP - random action for now
|
||||
# action = env.blue_agent_action(obs)
|
||||
action = 0
|
||||
print("Episode:", episode, "\nStep:", step)
|
||||
if step == 5:
|
||||
# [1, 1, 2, 1, 1, 1]
|
||||
# Creates an ACL rule
|
||||
@@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session):
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[
|
||||
TEST_CONFIG_ROOT
|
||||
/ "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
|
||||
TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
|
||||
]
|
||||
],
|
||||
|
||||
@@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT
|
||||
|
||||
def test_legacy_lay_down_config_yaml_conversion():
|
||||
"""Tests the conversion of legacy lay down config files."""
|
||||
legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
|
||||
legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
with open(legacy_path, "r") as file:
|
||||
legacy_dict = yaml.safe_load(file)
|
||||
@@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion():
|
||||
with open(new_path, "r") as file:
|
||||
new_dict = yaml.safe_load(file)
|
||||
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(
|
||||
legacy_dict
|
||||
)
|
||||
converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict)
|
||||
|
||||
for key, value in new_dict.items():
|
||||
assert converted_dict[key] == value
|
||||
@@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion():
|
||||
|
||||
def test_create_config_values_main_from_file():
|
||||
"""Tests creating an instance of TrainingConfig from file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml"
|
||||
|
||||
training_config.load(new_path)
|
||||
|
||||
|
||||
def test_create_config_values_main_from_legacy_file():
|
||||
"""Tests creating an instance of TrainingConfig from legacy file."""
|
||||
new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml"
|
||||
new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
|
||||
training_config.load(new_path, legacy_file=True)
|
||||
|
||||
Reference in New Issue
Block a user