Remove deprecated code from v2
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Access Control List. Models firewall functionality."""
|
||||
@@ -1,198 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
|
||||
"""Access Control List class."""
|
||||
|
||||
def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
|
||||
"""Init."""
|
||||
# Implicit ALLOW or DENY firewall spec
|
||||
self.acl_implicit_permission = implicit_permission
|
||||
# Implicit rule in ACL list
|
||||
if self.acl_implicit_permission == RulePermissionType.DENY:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
|
||||
elif self.acl_implicit_permission == RulePermissionType.ALLOW:
|
||||
self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
|
||||
else:
|
||||
raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")
|
||||
|
||||
# Maximum number of ACL Rules in ACL
|
||||
self.max_acl_rules: int = max_acl_rules
|
||||
# A list of ACL Rules
|
||||
self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Union[ACLRule, None]]:
|
||||
"""Public access method for private _acl."""
|
||||
return self._acl + [self.acl_implicit_rule]
|
||||
|
||||
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
|
||||
"""Checks for IP address matches.
|
||||
|
||||
:param _rule: The rule object to check
|
||||
:type _rule: ACLRule
|
||||
:param _source_ip_address: Source IP address to compare
|
||||
:type _source_ip_address: str
|
||||
:param _dest_ip_address: Destination IP address to compare
|
||||
:type _dest_ip_address: str
|
||||
:return: True if there is a match, otherwise False.
|
||||
:rtype: bool
|
||||
"""
|
||||
if (
|
||||
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address)
|
||||
or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY")
|
||||
or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY")
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
|
||||
"""
|
||||
Checks for rules that block a protocol / port.
|
||||
|
||||
Args:
|
||||
_source_ip_address: the source IP address to check
|
||||
_dest_ip_address: the destination IP address to check
|
||||
_protocol: the protocol to check
|
||||
_port: the port to check
|
||||
|
||||
Returns:
|
||||
Indicates block if all conditions are satisfied.
|
||||
"""
|
||||
for rule in self.acl:
|
||||
if isinstance(rule, ACLRule):
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
|
||||
):
|
||||
# There's a matching rule. Get the permission
|
||||
if rule.get_permission() == RulePermissionType.DENY:
|
||||
return True
|
||||
elif rule.get_permission() == RulePermissionType.ALLOW:
|
||||
return False
|
||||
|
||||
# If there has been no rule to allow the IER through, it will return a blocked signal by default
|
||||
return True
|
||||
|
||||
def add_rule(
|
||||
self,
|
||||
_permission: RulePermissionType,
|
||||
_source_ip: str,
|
||||
_dest_ip: str,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_position: str,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
_position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
|
||||
"""
|
||||
try:
|
||||
position_index = int(_position)
|
||||
except TypeError:
|
||||
_LOGGER.info(f"Position {_position} could not be converted to integer.")
|
||||
return
|
||||
|
||||
new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
# Checks position is in correct range
|
||||
if self.max_acl_rules - 1 > position_index > -1:
|
||||
try:
|
||||
_LOGGER.info(f"Position {position_index} is valid.")
|
||||
# Check to see Agent will not overwrite current ACL in ACL list
|
||||
if self._acl[position_index] is None:
|
||||
_LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
|
||||
# Adds rule
|
||||
self._acl[position_index] = new_rule
|
||||
else:
|
||||
# Cannot overwrite it
|
||||
_LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
|
||||
return
|
||||
except Exception:
|
||||
_LOGGER.info(f"New Rule could NOT be added to list at position {position_index}.")
|
||||
else:
|
||||
_LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")
|
||||
|
||||
def remove_rule(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Removes a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
"""
|
||||
rule_to_delete = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
delete_rule_hash = hash(rule_to_delete)
|
||||
|
||||
for index in range(0, len(self._acl)):
|
||||
if isinstance(self._acl[index], ACLRule) and hash(self._acl[index]) == delete_rule_hash:
|
||||
self._acl[index] = None
|
||||
|
||||
def remove_all_rules(self) -> None:
|
||||
"""Removes all rules."""
|
||||
for i in range(len(self._acl)):
|
||||
self._acl[i] = None
|
||||
|
||||
def get_dictionary_hash(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> int:
|
||||
"""
|
||||
Produces a hash value for a rule.
|
||||
|
||||
Args:
|
||||
_permission: the permission value (e.g. "ALLOW" or "DENY")
|
||||
_source_ip: the source IP address
|
||||
_dest_ip: the destination IP address
|
||||
_protocol: the protocol
|
||||
_port: the port
|
||||
|
||||
Returns:
|
||||
Hash value based on rule parameters.
|
||||
"""
|
||||
rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
|
||||
hash_value = hash(rule)
|
||||
return hash_value
|
||||
|
||||
def get_relevant_rules(
|
||||
self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all ACL rules that relate to the given arguments.
|
||||
|
||||
:param _source_ip_address: the source IP address to check
|
||||
:param _dest_ip_address: the destination IP address to check
|
||||
:param _protocol: the protocol to check
|
||||
:param _port: the port to check
|
||||
:return: Dictionary of all ACL rules that relate to the given arguments
|
||||
:rtype: Dict[int, ACLRule]
|
||||
"""
|
||||
relevant_rules = {}
|
||||
for rule in self.acl:
|
||||
if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
|
||||
if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
|
||||
str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
|
||||
):
|
||||
# There's a matching rule.
|
||||
relevant_rules[self._acl.index(rule)] = rule
|
||||
|
||||
return relevant_rules
|
||||
@@ -1,87 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
|
||||
"""Access Control List Rule class."""
|
||||
|
||||
def __init__(
|
||||
self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL Rule.
|
||||
|
||||
:param _permission: The permission (ALLOW or DENY)
|
||||
:param _source_ip: The source IP address
|
||||
:param _dest_ip: The destination IP address
|
||||
:param _protocol: The rule protocol
|
||||
:param _port: The rule port
|
||||
"""
|
||||
self.permission: RulePermissionType = _permission
|
||||
self.source_ip: str = _source_ip
|
||||
self.dest_ip: str = _dest_ip
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
Override the hash function.
|
||||
|
||||
Returns:
|
||||
Returns hash of core parameters.
|
||||
"""
|
||||
return hash(
|
||||
(
|
||||
self.permission,
|
||||
self.source_ip,
|
||||
self.dest_ip,
|
||||
self.protocol,
|
||||
self.port,
|
||||
)
|
||||
)
|
||||
|
||||
def get_permission(self) -> str:
|
||||
"""
|
||||
Gets the permission attribute.
|
||||
|
||||
Returns:
|
||||
Returns permission attribute
|
||||
"""
|
||||
return self.permission
|
||||
|
||||
def get_source_ip(self) -> str:
|
||||
"""
|
||||
Gets the source IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns source IP address attribute
|
||||
"""
|
||||
return self.source_ip
|
||||
|
||||
def get_dest_ip(self) -> str:
|
||||
"""
|
||||
Gets the desintation IP address attribute.
|
||||
|
||||
Returns:
|
||||
Returns destination IP address attribute
|
||||
"""
|
||||
return self.dest_ip
|
||||
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets the protocol attribute.
|
||||
|
||||
Returns:
|
||||
Returns protocol attribute
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets the port attribute.
|
||||
|
||||
Returns:
|
||||
Returns port attribute
|
||||
"""
|
||||
return self.port
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Common interface between RL agents from different libraries and PrimAITE."""
|
||||
@@ -1,319 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get the directory path the session will output to.
|
||||
|
||||
This is set in the format of:
|
||||
~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
|
||||
"""
|
||||
An ABC that manages training and/or evaluation of agents in PrimAITE.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
If training configuration and laydown configuration are provided with a session path,
|
||||
the session path will be used.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
# initialise variables
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
|
||||
# convert session to path
|
||||
if session_path is not None:
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
# load session
|
||||
self.load(session_path)
|
||||
else:
|
||||
# set training config path
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(
|
||||
self._training_config_path, legacy_file=legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = str(uuid4())
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
"""The session timestamp as a string."""
|
||||
return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@property
|
||||
def learning_path(self) -> Path:
|
||||
"""The learning outputs path."""
|
||||
path = self.session_path / "learning"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def evaluation_path(self) -> Path:
|
||||
"""The evaluation outputs path."""
|
||||
path = self.session_path / "evaluation"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def checkpoints_path(self) -> Path:
|
||||
"""The Session checkpoints path."""
|
||||
path = self.learning_path / "checkpoints"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def uuid(self) -> str:
|
||||
"""The Agent Session UUID."""
|
||||
return self._uuid
|
||||
|
||||
def _write_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Write the ``session_metadata.json`` file.
|
||||
|
||||
Creates a ``session_metadata.json`` in the ``session_path`` directory
|
||||
and adds the following key/value pairs:
|
||||
|
||||
- uuid: The UUID assigned to the session upon instantiation.
|
||||
- start_datetime: The date & time the session started in iso format.
|
||||
- end_datetime: NULL.
|
||||
- total_episodes: NULL.
|
||||
- total_time_steps: NULL.
|
||||
- env:
|
||||
- training_config:
|
||||
- All training config items
|
||||
- lay_down_config:
|
||||
- All lay down config items
|
||||
|
||||
"""
|
||||
metadata_dict = {
|
||||
"uuid": self.uuid,
|
||||
"start_datetime": self.session_timestamp.isoformat(),
|
||||
"end_datetime": None,
|
||||
"learning": {"total_episodes": None, "total_time_steps": None},
|
||||
"evaluation": {"total_episodes": None, "total_time_steps": None},
|
||||
"env": {
|
||||
"training_config": self._training_config.to_dict(json_serializable=True),
|
||||
"lay_down_config": self._lay_down_config,
|
||||
},
|
||||
}
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Writing Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished writing session metadata file")
|
||||
|
||||
def _update_session_metadata_file(self) -> None:
|
||||
"""
|
||||
Update the ``session_metadata.json`` file.
|
||||
|
||||
Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
with the following key/value pairs:
|
||||
|
||||
- end_datetime: The date & time the session ended in iso format.
|
||||
- total_episodes: The total number of training episodes completed.
|
||||
- total_time_steps: The total number of training time steps completed.
|
||||
"""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
metadata_dict = json.load(file)
|
||||
|
||||
metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
if not self.is_eval:
|
||||
metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count # noqa
|
||||
metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
else:
|
||||
metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count # noqa
|
||||
metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa
|
||||
|
||||
filepath = self.session_path / "session_metadata.json"
|
||||
_LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
with open(filepath, "w") as file:
|
||||
json.dump(metadata_dict, file)
|
||||
_LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self) -> None:
|
||||
_LOGGER.info(
|
||||
"Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
|
||||
)
|
||||
_LOGGER.info(f"The output directory for this session is: {self.session_path}")
|
||||
self._write_session_metadata_file()
|
||||
self._can_learn = True
|
||||
self._can_evaluate = False
|
||||
|
||||
@abstractmethod
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_learn:
|
||||
_LOGGER.info("Finished learning")
|
||||
_LOGGER.debug("Writing transactions")
|
||||
self._update_session_metadata_file()
|
||||
self._can_evaluate = True
|
||||
self.is_eval = False
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
if self._can_evaluate:
|
||||
self._update_session_metadata_file()
|
||||
self.is_eval = True
|
||||
self._plot_av_reward_per_episode(learning_session=False)
|
||||
_LOGGER.info("Finished evaluation")
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def load(self, path: Union[str, Path]) -> None:
|
||||
"""Load an agent from file."""
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = laydown_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._env.episode_av_reward_writer.close() # noqa
|
||||
self._env.transaction_writer.close() # noqa
|
||||
|
||||
def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
|
||||
# self.close()
|
||||
title = f"PrimAITE Session {self.timestamp_str} "
|
||||
subtitle = str(self._training_config)
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
|
||||
if learning_session:
|
||||
title += "(Learning)"
|
||||
path = self.learning_path / csv_file
|
||||
image_path = self.learning_path / image_file
|
||||
else:
|
||||
title += "(Evaluation)"
|
||||
path = self.evaluation_path / csv_file
|
||||
image_path = self.evaluation_path / image_file
|
||||
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
@@ -1,118 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path] = None) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
@@ -1,515 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
transform_action_acl_enum,
|
||||
transform_change_obs_readable,
|
||||
)
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardCodedAgentView
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
# full view action using observation space, action
|
||||
# history and reward feedback
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
:type green_iers: Dict[str, IER]
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
|
||||
:rtype: Dict[str, IER]
|
||||
"""
|
||||
blocked_green_iers = {}
|
||||
|
||||
for green_ier_id, green_ier in green_iers.items():
|
||||
source_node_id = green_ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = green_ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = green_ier.get_protocol() # e.g. 'TCP'
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
source_node_id = ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
(No ALLOW rule, therefore blocked).
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
blocked_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "DENY":
|
||||
blocked_rules[rule_key] = rule_value
|
||||
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get all allowing ACL rules for an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "ALLOW":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_matching_acl_rules(
|
||||
self,
|
||||
source_node_id: str,
|
||||
dest_node_id: str,
|
||||
protocol: str,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Filter ACL rules to only those which are relevant to the specified nodes.
|
||||
|
||||
:param source_node_id: Source node
|
||||
:type source_node_id: str
|
||||
:param dest_node_id: Destination nodes
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: str
|
||||
:param port: Network port
|
||||
:type port: str
|
||||
:param acl: Access Control list which will be filtered
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The environment's node directory.
|
||||
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
|
||||
:param services_list: List of services registered for the environment.
|
||||
:type services_list: List[str]
|
||||
:return: Filtered version of 'acl'
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
if source_node_id != "ANY":
|
||||
source_node_address = nodes[str(source_node_id)].ip_address
|
||||
else:
|
||||
source_node_address = source_node_id
|
||||
|
||||
if dest_node_id != "ANY":
|
||||
dest_node_address = nodes[str(dest_node_id)].ip_address
|
||||
else:
|
||||
dest_node_address = dest_node_id
|
||||
|
||||
if protocol != "ANY":
|
||||
protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services
|
||||
# TODO: This should throw an error because protocol is a string
|
||||
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
|
||||
self,
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List ALLOW rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "ALLOW":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def get_deny_acl_rules(
|
||||
self,
|
||||
source_node_id: int,
|
||||
dest_node_id: str,
|
||||
protocol: int,
|
||||
port: str,
|
||||
acl: AccessControlList,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
services_list: List[str],
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""List DENY rules relating to specified nodes.
|
||||
|
||||
:param source_node_id: Source node id
|
||||
:type source_node_id: int
|
||||
:param dest_node_id: Destination node
|
||||
:type dest_node_id: str
|
||||
:param protocol: Network protocol
|
||||
:type protocol: int
|
||||
:param port: Port
|
||||
:type port: str
|
||||
:param acl: Firewall ruleset which is applied to the network
|
||||
:type acl: AccessControlList
|
||||
:param nodes: The simulation's node store
|
||||
:type nodes: Dict[str, NodeUnion]
|
||||
:param services_list: Services list
|
||||
:type services_list: List[str]
|
||||
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
|
||||
desination nodes
|
||||
:rtype: Dict[str, ACLRule]
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules(
|
||||
source_node_id,
|
||||
dest_node_id,
|
||||
protocol,
|
||||
port,
|
||||
acl,
|
||||
nodes,
|
||||
services_list,
|
||||
)
|
||||
|
||||
allowed_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "DENY":
|
||||
allowed_rules[rule_key] = rule_value
|
||||
|
||||
return allowed_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:
|
||||
|
||||
- Which ACL rules already exist, - otherwise:
|
||||
- The agent would perminently get stuck in a loop of performing the same action over and over.
|
||||
(best action is to block something, but its already blocked but doesn't know this)
|
||||
- The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
|
||||
if it doesnt know what rules exist)
|
||||
- The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
|
||||
in the default config one of the green IERs is blocked by default, but it has no way of knowing this
|
||||
based on the observation space. Additionally, potentially in the future, once a node state
|
||||
has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
|
||||
A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.
|
||||
|
||||
There doesn't seem like there's much that can be done if an Operating or OS State is compromised
|
||||
|
||||
If a service node becomes compromised there's a decision to make - do we block that service?
|
||||
Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED
|
||||
Cons: Will block a green IER, decreasing the reward
|
||||
We decide to block the service.
|
||||
|
||||
Potentially a better solution (for the reward) would be to block the incomming traffic from compromised
|
||||
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
|
||||
an overwhelmed state, so we don't do this.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
# obs = convert_to_old_obs(obs)
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, _, _, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
# 1. Check if node is compromised. If so we want to block its outwards services
|
||||
# a. If it is comprimised check if there's an allow rule we should delete.
|
||||
# cons: might delete a multi-rule from any source node (ANY -> x)
|
||||
# b. OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate
|
||||
# c. OPTIONAL (no allow rule = blocked): Add a DENY rule
|
||||
found_action = False
|
||||
for service_num, service_states in enumerate(s):
|
||||
for x, service_state in enumerate(service_states):
|
||||
if service_state == "COMPROMISED":
|
||||
action_source_id = x + 1 # +1 as 0 is any
|
||||
action_destination_id = "ANY"
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
action_port = "ANY"
|
||||
|
||||
allow_rules = self.get_allow_acl_rules(
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
self._env.acl,
|
||||
self._env.nodes,
|
||||
self._env.services_list,
|
||||
)
|
||||
deny_rules = self.get_deny_acl_rules(
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
self._env.acl,
|
||||
self._env.nodes,
|
||||
self._env.services_list,
|
||||
)
|
||||
if len(allow_rules) > 0:
|
||||
# Check if there's an allow rule we should delete
|
||||
rule = list(allow_rules.values())[0]
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = rule.get_source_ip()
|
||||
action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
|
||||
action_destination_ip = rule.get_dest_ip()
|
||||
action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
|
||||
action_protocol_name = rule.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = rule.get_port()
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
elif len(deny_rules) > 0:
|
||||
# TODO OPTIONAL
|
||||
# If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need
|
||||
# to create another
|
||||
# Check to see if the DENY rule really blocks everything (ANY) or just a specific rule
|
||||
continue
|
||||
else:
|
||||
# TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
|
||||
action_decision = "CREATE"
|
||||
action_permission = "DENY"
|
||||
break
|
||||
if found_action:
|
||||
break
|
||||
|
||||
# 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and
|
||||
# add an Allow rule if the green IER is being blocked.
|
||||
# a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
|
||||
# If there's a DENY rule delete it if:
|
||||
# - There isn't already a deny rule
|
||||
# - It doesnt allows a comprimised node to become operational.
|
||||
# b. Add an ALLOW rule if:
|
||||
# - There isn't already an allow rule
|
||||
# - It doesnt allows a comprimised node to become operational
|
||||
|
||||
if not found_action:
|
||||
# Which Green IERS are blocked
|
||||
blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
|
||||
for ier_key, ier in blocked_green_iers.items():
|
||||
# Which ALLOW rules are allowing this IER (none)
|
||||
allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)
|
||||
|
||||
# If there are no blocking rules, it may be being blocked by default
|
||||
# If there is already an allow rule
|
||||
node_id_to_check = int(ier.get_source_node_id())
|
||||
service_name_to_check = ier.get_protocol()
|
||||
service_id_to_check = self._env.services_list.index(service_name_to_check)
|
||||
|
||||
# Service state of the the source node in the ier
|
||||
service_state = s[service_id_to_check][node_id_to_check - 1]
|
||||
|
||||
if len(allowing_rules) == 0 and service_state != "COMPROMISED":
|
||||
action_decision = "CREATE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_id = int(ier.get_source_node_id())
|
||||
action_destination_id = int(ier.get_dest_node_id())
|
||||
action_protocol_name = ier.get_protocol()
|
||||
action_protocol = (
|
||||
self._env.services_list.index(action_protocol_name) + 1
|
||||
) # convert name e.g. 'TCP' to index
|
||||
action_port_name = ier.get_port()
|
||||
action_port = (
|
||||
self._env.ports_list.index(action_port_name) + 1
|
||||
) # convert port name e.g. '80' to index
|
||||
|
||||
found_action = True
|
||||
break
|
||||
|
||||
if found_action:
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
action_source_id,
|
||||
action_destination_id,
|
||||
action_protocol,
|
||||
action_port,
|
||||
]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, self._env.action_dict)
|
||||
else:
|
||||
# If no good/useful action has been found, just perform a nothing action
|
||||
action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, self._env.action_dict)
|
||||
return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good acl-based action for the blue agent to take.
|
||||
|
||||
Uses ONLY information from the current observation with NO knowledge
|
||||
of previous actions taken and NO reward feedback.
|
||||
|
||||
We rely on randomness to select the precise action, as we want to
|
||||
block all traffic originating from a compromised node, without being
|
||||
able to tell:
|
||||
1. Which ACL rules already exist
|
||||
2. Which actions the agent has already tried.
|
||||
|
||||
There is a high probability that the correct rule will not be deleted
|
||||
before the state becomes overwhelmed.
|
||||
|
||||
Currently, a deny rule does not overwrite an allow rule. The allow
|
||||
rules must be deleted.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, o, _, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links)
|
||||
for service_num, service_states in enumerate(s):
|
||||
comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
|
||||
if len(comprimised_states) == 0:
|
||||
# No states are COMPROMISED, try the next service
|
||||
continue
|
||||
|
||||
compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any
|
||||
action_decision = "DELETE"
|
||||
action_permission = "ALLOW"
|
||||
action_source_ip = compromised_node
|
||||
# Randomly select a destination ID to block
|
||||
action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
|
||||
action_destination_ip = (
|
||||
int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
|
||||
)
|
||||
action_protocol = service_num + 1 # +1 as 0 is any
|
||||
# Randomly select a port to block
|
||||
# Bad assumption that number of protocols equals number of ports
|
||||
# AND no rules exist with an ANY port
|
||||
action_port = np.random.choice(list(range(1, len(s) + 1)))
|
||||
|
||||
action = [
|
||||
action_decision,
|
||||
action_permission,
|
||||
action_source_ip,
|
||||
action_destination_ip,
|
||||
action_protocol,
|
||||
action_port,
|
||||
]
|
||||
action = transform_action_acl_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# If no good/useful action has been found, just perform a nothing action
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, action_dict)
|
||||
return nothing_action
|
||||
@@ -1,125 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic Node agent."""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Calculate a good node-based action for the blue agent to take.
|
||||
|
||||
:param obs: current observation from the gym environment
|
||||
:type obs: np.ndarray
|
||||
:return: Optimal action to take in the environment (chosen from the discrete action space)
|
||||
:rtype: int
|
||||
"""
|
||||
action_dict = self._env.action_dict
|
||||
r_obs = transform_change_obs_readable(obs)
|
||||
_, o, os, *s = r_obs
|
||||
|
||||
if len(r_obs) == 4: # only 1 service
|
||||
s = [*s]
|
||||
|
||||
# Check in order of most important states (order doesn't currently
|
||||
# matter, but it probably should)
|
||||
# First see if any OS states are compromised
|
||||
for x, os_state in enumerate(os):
|
||||
if os_state == "COMPROMISED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OS"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = 0 # does nothing isn't relevant for os
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Next, see if any Services are compromised
|
||||
# We fix the compromised state before overwhelemd state,
|
||||
# If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
|
||||
for service_num, service in enumerate(s):
|
||||
for x, service_state in enumerate(service):
|
||||
if service_state == "COMPROMISED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "SERVICE"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Next, See if any services are overwhelmed
|
||||
# perhaps this should be fixed automatically when the compromised PCs issues are also resolved
|
||||
# Currently there's no reason that an Overwhelmed state cannot be resolved before resolving the compromised PCs
|
||||
|
||||
for service_num, service in enumerate(s):
|
||||
for x, service_state in enumerate(service):
|
||||
if service_state == "OVERWHELMED":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "SERVICE"
|
||||
property_action = "PATCHING"
|
||||
action_service_index = service_num
|
||||
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# Finally, turn on any off nodes
|
||||
for x, operating_state in enumerate(o):
|
||||
if os_state == "OFF":
|
||||
action_node_id = x + 1
|
||||
action_node_property = "OPERATING"
|
||||
property_action = "ON" # Why reset it when we can just turn it on
|
||||
action_service_index = 0 # does nothing isn't relevant for operating state
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
# TODO: transform_action_node_enum takes only one argument, not sure why two are given here.
|
||||
action = transform_action_node_enum(action, action_dict)
|
||||
action = get_new_action(action, action_dict)
|
||||
# We can only perform 1 action on each step
|
||||
return action
|
||||
|
||||
# If no good actions, just go with an action that wont do any harm
|
||||
action_node_id = 1
|
||||
action_node_property = "NONE"
|
||||
property_action = "ON"
|
||||
action_service_index = 0
|
||||
action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
action = transform_action_node_enum(action)
|
||||
action = get_new_action(action, action_dict)
|
||||
|
||||
return action
|
||||
@@ -1,287 +0,0 @@
|
||||
# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# from __future__ import annotations
|
||||
|
||||
# import json
|
||||
# import shutil
|
||||
# import zipfile
|
||||
# from datetime import datetime
|
||||
# from logging import Logger
|
||||
# from pathlib import Path
|
||||
# from typing import Any, Callable, Dict, Optional, Union
|
||||
# from uuid import uuid4
|
||||
|
||||
# from primaite import getLogger
|
||||
# from primaite.agents.agent_abc import AgentSessionABC
|
||||
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
|
||||
# # from ray.rllib.algorithms import Algorithm
|
||||
# # from ray.rllib.algorithms.a2c import A2CConfig
|
||||
# # from ray.rllib.algorithms.ppo import PPOConfig
|
||||
# # from ray.tune.logger import UnifiedLogger
|
||||
# # from ray.tune.registry import register_env
|
||||
|
||||
|
||||
# # from primaite.exceptions import RLlibAgentError
|
||||
|
||||
# _LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
# # TODO: verify type of env_config
|
||||
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
# return Primaite(
|
||||
# training_config_path=env_config["training_config_path"],
|
||||
# lay_down_config_path=env_config["lay_down_config_path"],
|
||||
# session_path=env_config["session_path"],
|
||||
# timestamp_str=env_config["timestamp_str"],
|
||||
# )
|
||||
|
||||
# # # TODO: verify type hint return type
|
||||
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
# # logdir = session_path / "ray_results"
|
||||
# # logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# # def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
# # return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
# return logger_creator
|
||||
|
||||
|
||||
# # class RLlibAgent(AgentSessionABC):
|
||||
# # """An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
# # def __init__(
|
||||
# # self,
|
||||
# # training_config_path: Optional[Union[str, Path]] = "",
|
||||
# # lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
# # session_path: Optional[Union[str, Path]] = None,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Initialise the RLLib Agent training session.
|
||||
|
||||
# # :param training_config_path: YAML file containing configurable items defined in
|
||||
# # `primaite.config.training_config.TrainingConfig`
|
||||
# # :type training_config_path: Union[path, str]
|
||||
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
# # :type lay_down_config_path: Union[path, str]
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO`
|
||||
# # or `A2C`)
|
||||
# # """
|
||||
# # # TODO: implement RLlib agent loading
|
||||
# # if session_path is not None:
|
||||
# # msg = "RLlib agent loading has not been implemented yet"
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise NotImplementedError(msg)
|
||||
|
||||
# # super().__init__(training_config_path, lay_down_config_path)
|
||||
# # if self._training_config.session_type == SessionType.EVAL:
|
||||
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise RLlibAgentError(msg)
|
||||
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
# # self._agent_config_class = PPOConfig
|
||||
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
# # self._agent_config_class = A2CConfig
|
||||
# # else:
|
||||
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
|
||||
# # self._current_result: dict
|
||||
# # self._setup()
|
||||
# # _LOGGER.debug(
|
||||
# # f"Created {self.__class__.__name__} using: "
|
||||
# # f"agent_framework={self._training_config.agent_framework}, "
|
||||
# # f"agent_identifier="
|
||||
# # f"{self._training_config.agent_identifier}, "
|
||||
# # f"deep_learning_framework="
|
||||
# # f"{self._training_config.deep_learning_framework}"
|
||||
# # )
|
||||
# # self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
|
||||
# # def _update_session_metadata_file(self) -> None:
|
||||
# # """
|
||||
# # Update the ``session_metadata.json`` file.
|
||||
|
||||
# # Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
# # with the following key/value pairs:
|
||||
|
||||
# # - end_datetime: The date & time the session ended in iso format.
|
||||
# # - total_episodes: The total number of training episodes completed.
|
||||
# # - total_time_steps: The total number of training time steps completed.
|
||||
# # """
|
||||
# # with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
# # metadata_dict = json.load(file)
|
||||
|
||||
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
# # if not self.is_eval:
|
||||
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
# # else:
|
||||
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
|
||||
# # filepath = self.session_path / "session_metadata.json"
|
||||
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
# # with open(filepath, "w") as file:
|
||||
# # json.dump(metadata_dict, file)
|
||||
# # _LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
# # def _setup(self) -> None:
|
||||
# # super()._setup()
|
||||
# # register_env("primaite", _env_creator)
|
||||
# # self._agent_config = self._agent_config_class()
|
||||
|
||||
# # self._agent_config.environment(
|
||||
# # env="primaite",
|
||||
# # env_config=dict(
|
||||
# # training_config_path=self._training_config_path,
|
||||
# # lay_down_config_path=self._lay_down_config_path,
|
||||
# # session_path=self.session_path,
|
||||
# # timestamp_str=self.timestamp_str,
|
||||
# # ),
|
||||
# # )
|
||||
# # self._agent_config.seed = self._training_config.seed
|
||||
|
||||
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
# # self._agent_config.framework(framework="tf")
|
||||
|
||||
# # self._agent_config.rollouts(
|
||||
# # num_rollout_workers=1,
|
||||
# # num_envs_per_worker=1,
|
||||
# # horizon=self._training_config.num_train_steps,
|
||||
# # )
|
||||
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
# # def _save_checkpoint(self) -> None:
|
||||
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
# # episode_count = self._current_result["episodes_total"]
|
||||
# # save_checkpoint = False
|
||||
# # if checkpoint_n:
|
||||
# # save_checkpoint = episode_count % checkpoint_n == 0
|
||||
# # if episode_count and save_checkpoint:
|
||||
# # self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
# # def learn(
|
||||
# # self,
|
||||
# # **kwargs: Any,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_train_steps
|
||||
# # episodes = self._training_config.num_train_episodes
|
||||
|
||||
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
# # for i in range(episodes):
|
||||
# # self._current_result = self._agent.train()
|
||||
# # self._save_checkpoint()
|
||||
# # self.save()
|
||||
# # super().learn()
|
||||
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent = self._agent
|
||||
# # else:
|
||||
# # self._agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
# # def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||
# # agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
# # if agent_restore_path.exists():
|
||||
# # shutil.rmtree(agent_restore_path)
|
||||
# # agent_restore_path.mkdir()
|
||||
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
# # zip_file.extractall(agent_restore_path)
|
||||
# # return agent_restore_path
|
||||
|
||||
# # def _setup_eval(self):
|
||||
# # self._can_learn = False
|
||||
# # self._can_evaluate = True
|
||||
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
|
||||
# # def evaluate(
|
||||
# # self,
|
||||
# # **kwargs,
|
||||
# # ):
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_eval_steps
|
||||
# # episodes = self._training_config.num_eval_episodes
|
||||
|
||||
# # self._setup_eval()
|
||||
|
||||
# # self._env: Primaite = Primaite(
|
||||
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
# # )
|
||||
|
||||
# # self._env.set_as_eval()
|
||||
# # self.is_eval = True
|
||||
# # if self._training_config.deterministic:
|
||||
# # deterministic_str = "deterministic"
|
||||
# # else:
|
||||
# # deterministic_str = "non-deterministic"
|
||||
# # _LOGGER.info(
|
||||
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
# # )
|
||||
# # for episode in range(episodes):
|
||||
# # obs = self._env.reset()
|
||||
# # for step in range(time_steps):
|
||||
# # action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
|
||||
# # obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
# # self._env.reset()
|
||||
# # self._env.close()
|
||||
# # super().evaluate()
|
||||
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
# # # Perform a clean-up of the unpacked agent
|
||||
# # if (self.evaluation_path / "agent_restore").exists():
|
||||
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||
|
||||
# # def _get_latest_checkpoint(self) -> None:
|
||||
# # raise NotImplementedError
|
||||
|
||||
# # @classmethod
|
||||
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
# # """Load an agent from file."""
|
||||
# # raise NotImplementedError
|
||||
|
||||
# # def save(self, overwrite_existing: bool = True) -> None:
|
||||
# # """Save the agent."""
|
||||
# # # Make temp dir to save in isolation
|
||||
# # temp_dir = self.learning_path / str(uuid4())
|
||||
# # temp_dir.mkdir()
|
||||
|
||||
# # # Save the agent to the temp dir
|
||||
# # self._agent.save(str(temp_dir))
|
||||
|
||||
# # # Capture the saved Rllib checkpoint inside the temp directory
|
||||
# # for file in temp_dir.iterdir():
|
||||
# # checkpoint_dir = file
|
||||
# # break
|
||||
|
||||
# # # Zip the folder
|
||||
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# # # Drop the temp directory
|
||||
# # shutil.rmtree(temp_dir)
|
||||
|
||||
# # def export(self) -> None:
|
||||
# # """Export the agent to transportable file format."""
|
||||
# # raise NotImplementedError
|
||||
@@ -1,206 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(
|
||||
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
self._agent_class: Union[PPO, A2C]
|
||||
if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
self._agent_class = PPO
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
self._agent_class = A2C
|
||||
else:
|
||||
msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
f"agent_identifier="
|
||||
f"{self._training_config.agent_identifier}"
|
||||
)
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
legacy_training_config=self.legacy_training_config,
|
||||
legacy_lay_down_config=self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
# check if there is a zip file that needs to be loaded
|
||||
load_file = next(self.session_path.rglob("*.zip"), None)
|
||||
|
||||
if not load_file:
|
||||
# create a new env and agent
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
else:
|
||||
# set env values from session metadata
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# load environment values
|
||||
if self.is_eval:
|
||||
# evaluation always starts at 0
|
||||
self._env.episode_count = 0
|
||||
self._env.total_step_count = 0
|
||||
else:
|
||||
# carry on from previous learning sessions
|
||||
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
||||
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
||||
|
||||
# load the file
|
||||
self._agent = self._agent_class.load(load_file, env=self._env)
|
||||
|
||||
# set agent values
|
||||
self._agent.verbose = self.sb3_output_verbose_level
|
||||
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._env.episode_count
|
||||
save_checkpoint = False
|
||||
if checkpoint_n:
|
||||
save_checkpoint = episode_count % checkpoint_n == 0
|
||||
if episode_count and save_checkpoint:
|
||||
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
|
||||
self._agent.save(checkpoint_path)
|
||||
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_train_steps
|
||||
episodes = self._training_config.num_train_episodes
|
||||
self.is_eval = False
|
||||
_LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
self._save_checkpoint()
|
||||
self._env._write_av_reward_per_episode() # noqa
|
||||
self.save()
|
||||
self._env.close()
|
||||
super().learn()
|
||||
|
||||
# save agent
|
||||
self.save()
|
||||
|
||||
self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
self._env.set_as_eval()
|
||||
self.is_eval = True
|
||||
if self._training_config.deterministic:
|
||||
deterministic_str = "deterministic"
|
||||
else:
|
||||
deterministic_str = "non-deterministic"
|
||||
_LOGGER.info(
|
||||
f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
)
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = np.int64(action)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
self._env._write_av_reward_per_episode() # noqa
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
raise NotImplementedError
|
||||
@@ -1,59 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Random Agent.
|
||||
|
||||
Get a completely random action from the action space.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return self._env.action_space.sample()
|
||||
|
||||
|
||||
class DummyAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A Dummy Agent.
|
||||
|
||||
All action spaces setup so dummy action is always 0 regardless of action type used.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class DoNothingACLAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing ACL agent.
|
||||
|
||||
A valid ACL action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
|
||||
nothing_action = transform_action_acl_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
|
||||
return nothing_action
|
||||
|
||||
|
||||
class DoNothingNodeAgent(HardCodedAgentSessionABC):
|
||||
"""
|
||||
A do nothing Node agent.
|
||||
|
||||
A valid Node action that has no effect; does nothing.
|
||||
"""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
nothing_action = [1, "NONE", "ON", 0]
|
||||
nothing_action = transform_action_node_enum(nothing_action)
|
||||
nothing_action = get_new_action(nothing_action, self._env.action_dict)
|
||||
# nothing_action should currently always be 0
|
||||
|
||||
return nothing_action
|
||||
@@ -1,450 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
LinkStatus,
|
||||
NodeHardwareAction,
|
||||
NodePOLType,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
)
|
||||
|
||||
|
||||
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
|
||||
"""Convert a node action from enumerated format to readable format.
|
||||
|
||||
example:
|
||||
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_node_property = NodePOLType(action[1]).name
|
||||
|
||||
if action_node_property == "OPERATING":
|
||||
property_action = NodeHardwareAction(action[2]).name
|
||||
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
|
||||
property_action = NodeSoftwareAction(action[2]).name
|
||||
else:
|
||||
property_action = "NONE"
|
||||
|
||||
new_action: list[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
|
||||
"""
|
||||
Transform an ACL action to a more readable format.
|
||||
|
||||
example:
|
||||
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: The same action list, but with the encodings translated back into meaningful labels
|
||||
:rtype: List[Union[int,str]]
|
||||
"""
|
||||
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
|
||||
action_permissions = {0: "DENY", 1: "ALLOW"}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == 0:
|
||||
new_action[n + 2] = "ANY"
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def is_valid_node_action(action: List[int]) -> bool:
|
||||
"""
|
||||
Is the node action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect
|
||||
|
||||
Does NOT consider:
|
||||
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
|
||||
- Node already being in that state (turning an ON node ON)
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_node_readable(action)
|
||||
|
||||
node_property = action_r[1]
|
||||
node_action = action_r[2]
|
||||
|
||||
# print("node property", node_property, "\nnode action", node_action)
|
||||
|
||||
if node_property == "NONE":
|
||||
return False
|
||||
if node_action == "NONE":
|
||||
return False
|
||||
if node_property == "OPERATING" and node_action == "PATCHING":
|
||||
# Operating State cannot PATCH
|
||||
return False
|
||||
if node_property != "OPERATING" and node_action not in [
|
||||
"NONE",
|
||||
"PATCHING",
|
||||
]:
|
||||
# Software States can only do Nothing or Patch
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action(action: List[int]) -> bool:
|
||||
"""
|
||||
Is the ACL action an actual valid action.
|
||||
|
||||
Only uses information about the action to determine if the action has an effect.
|
||||
|
||||
Does NOT consider:
|
||||
- Trying to create identical rules
|
||||
- Trying to create a rule which is a subset of another rule (caused by "ANY")
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
action_r = transform_action_acl_readable(action)
|
||||
|
||||
action_decision = action_r[0]
|
||||
action_permission = action_r[1]
|
||||
action_source_id = action_r[2]
|
||||
action_destination_id = action_r[3]
|
||||
|
||||
if action_decision == "NONE":
|
||||
return False
|
||||
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
|
||||
# ACL rule towards itself
|
||||
return False
|
||||
if action_permission == "DENY":
|
||||
# DENY is unnecessary, we can create and delete allow rules instead
|
||||
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action: List[int]) -> bool:
|
||||
"""
|
||||
Harsher version of valid acl actions, does not allow action.
|
||||
|
||||
:param action: Agent action, formatted as a list of ints, for more information check out
|
||||
`primaite.environment.primaite_env.Primaite`
|
||||
:type action: List[int]
|
||||
:return: Whether the action is valid
|
||||
:rtype: bool
|
||||
"""
|
||||
if is_valid_acl_action(action) is False:
|
||||
return False
|
||||
|
||||
action_r = transform_action_acl_readable(action)
|
||||
action_protocol = action_r[4]
|
||||
action_port = action_r[5]
|
||||
|
||||
# Don't allow protocols or ports to be ANY
|
||||
# in the future we might want to do the opposite, and only have ANY option for ports and service
|
||||
if action_protocol == "ANY":
|
||||
return False
|
||||
if action_port == "ANY":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform list of transactions to readable list of each observation property.
|
||||
|
||||
example:
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
|
||||
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
ids = [i for i in obs[:, 0]]
|
||||
operating_states = [HardwareState(i).name for i in obs[:, 1]]
|
||||
os_states = [SoftwareState(i).name for i in obs[:, 2]]
|
||||
new_obs = [ids, operating_states, os_states]
|
||||
|
||||
for service in range(4, obs.shape[1]):
|
||||
# Links bit/s don't have a service state
|
||||
service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
|
||||
new_obs.append(service_states)
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
|
||||
"""Transform observation to readable format.
|
||||
|
||||
example
|
||||
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
|
||||
|
||||
:param obs: Raw observation from the environment.
|
||||
:type obs: np.ndarray
|
||||
:return: The same observation, but the encoded integer values are replaced with readable names.
|
||||
:rtype: List[List[Union[str, int]]]
|
||||
"""
|
||||
changed_obs = transform_change_obs_readable(obs)
|
||||
new_obs = list(zip(*changed_obs))
|
||||
# Convert list of tuples to list of lists
|
||||
new_obs = [list(i) for i in new_obs]
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
|
||||
"""Convert original gym Box observation space to new multiDiscrete observation space.
|
||||
|
||||
:param obs: observation in the 'old' (NodeLinkTable) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:return: reformatted observation
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
|
||||
new_obs = obs[:num_nodes, 1:].flatten()
|
||||
return new_obs
|
||||
|
||||
|
||||
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
|
||||
"""Convert to old observation.
|
||||
|
||||
Links filled with 0's as no information is included in new observation space.
|
||||
|
||||
example:
|
||||
obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])
|
||||
|
||||
new_obs = array([[ 1, 1, 1, 1],
|
||||
[ 2, 1, 1, 1],
|
||||
[ 3, 1, 1, 1],
|
||||
...
|
||||
[20, 0, 0, 0]])
|
||||
|
||||
:param obs: observation in the 'new' (MultiDiscrete) format
|
||||
:type obs: np.ndarray
|
||||
:param num_nodes: number of nodes in the network, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: number of links in the network, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: number of services on the network, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: 2-d BOX observation space, in the same format as NodeLinkTable
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
# Convert back to more readable, original format
|
||||
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
|
||||
|
||||
# Add empty links back and add node ID back
|
||||
s = np.zeros(
|
||||
[reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
|
||||
dtype=np.int64,
|
||||
)
|
||||
s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back
|
||||
s[:num_nodes, 1:] = reshaped_nodes # put values back in
|
||||
new_obs = s
|
||||
|
||||
# Add links back in
|
||||
links = obs[-num_links:]
|
||||
# Links will be added to the last protocol/service slot but they are not specific to that service
|
||||
new_obs[num_nodes:, -1] = links
|
||||
|
||||
return new_obs
|
||||
|
||||
|
||||
def describe_obs_change(
|
||||
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
|
||||
) -> str:
|
||||
"""Build a string describing the difference between two observations.
|
||||
|
||||
example:
|
||||
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
|
||||
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
|
||||
output = 'ID 1: SERVICE 2 set to GOOD'
|
||||
|
||||
:param obs1: First observation
|
||||
:type obs1: np.ndarray
|
||||
:param obs2: Second observation
|
||||
:type obs2: np.ndarray
|
||||
:param num_nodes: How many nodes are in the network laydown, defaults to 10
|
||||
:type num_nodes: int, optional
|
||||
:param num_links: How many links are in the network laydown, defaults to 10
|
||||
:type num_links: int, optional
|
||||
:param num_services: How many services are configured for this scenario, defaults to 1
|
||||
:type num_services: int, optional
|
||||
:return: A multi-line string with a human-readable description of the difference.
|
||||
:rtype: str
|
||||
"""
|
||||
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
|
||||
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
|
||||
list_of_changes = []
|
||||
for n, row in enumerate(obs1 - obs2):
|
||||
if row.any() != 0:
|
||||
relevant_changes = np.where(row != 0, obs2[n], -1)
|
||||
relevant_changes[0] = obs2[n, 0] # ID is always relevant
|
||||
is_link = relevant_changes[0] > num_nodes
|
||||
desc = _describe_obs_change_helper(relevant_changes, is_link)
|
||||
list_of_changes.append(desc)
|
||||
|
||||
change_string = "\n ".join(list_of_changes)
|
||||
if len(list_of_changes) > 0:
|
||||
change_string = "\n " + change_string
|
||||
return change_string
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
|
||||
"""
|
||||
Helper funcion to describe what has changed.
|
||||
|
||||
example:
|
||||
[ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD"
|
||||
|
||||
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
|
||||
|
||||
:param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one
|
||||
row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new
|
||||
status where it has changed.
|
||||
:type obs_change: List[int]
|
||||
:param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
|
||||
:type is_link: bool
|
||||
:return: A human-readable description of the difference between the two observation rows.
|
||||
:rtype: str
|
||||
"""
|
||||
# Indexes where a change has occured, not including 0th index
|
||||
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
|
||||
# Node pol types, Indexes >= 3 are service nodes
|
||||
NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed]
|
||||
# Account for hardware states, software sattes and links
|
||||
states = [
|
||||
LinkStatus(obs_change[i]).name
|
||||
if is_link
|
||||
else HardwareState(obs_change[i]).name
|
||||
if i == 1
|
||||
else SoftwareState(obs_change[i]).name
|
||||
for i in index_changed
|
||||
]
|
||||
|
||||
if not is_link:
|
||||
desc = f"ID {obs_change[0]}:"
|
||||
for node_pol_type, state in list(zip(NodePOLTypes, states)):
|
||||
desc = desc + " " + node_pol_type + " changed to " + state + "."
|
||||
else:
|
||||
desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."
|
||||
|
||||
return desc
|
||||
|
||||
|
||||
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
|
||||
"""Convert a node action from readable string format, to enumerated format.
|
||||
|
||||
example:
|
||||
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
|
||||
:param action: Action in 'readable' format
|
||||
:type action: List[Union[str,int]]
|
||||
:return: Action with verbs encoded as ints
|
||||
:rtype: List[int]
|
||||
"""
|
||||
action_node_id = action[0]
|
||||
action_node_property = NodePOLType[action[1]].value
|
||||
|
||||
if action[1] == "OPERATING":
|
||||
property_action = NodeHardwareAction[action[2]].value
|
||||
elif action[1] == "OS" or action[1] == "SERVICE":
|
||||
property_action = NodeSoftwareAction[action[2]].value
|
||||
else:
|
||||
property_action = 0
|
||||
|
||||
action_service_index = action[3]
|
||||
|
||||
new_action = [
|
||||
action_node_id,
|
||||
action_node_property,
|
||||
property_action,
|
||||
action_service_index,
|
||||
]
|
||||
|
||||
return new_action
|
||||
|
||||
|
||||
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
|
||||
"""
|
||||
Convert acl action from readable str format, to enumerated format.
|
||||
|
||||
:param action: ACL-based action expressed as a list of human-readable ints and strings
|
||||
:type action: List[Union[int,str]]
|
||||
:return: The same action but encoded to contain only integers.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
|
||||
action_permissions = {"DENY": 0, "ALLOW": 1}
|
||||
|
||||
action_decision = action_decisions[action[0]]
|
||||
action_permission = action_permissions[action[1]]
|
||||
|
||||
# For IPs, Ports and Protocols, ANY has value 0, otherwise its just an index
|
||||
new_action = [action_decision, action_permission] + list(action[2:6])
|
||||
for n, val in enumerate(list(action[2:6])):
|
||||
if val == "ANY":
|
||||
new_action[n + 2] = 0
|
||||
|
||||
new_action = np.array(new_action)
|
||||
return new_action
|
||||
|
||||
|
||||
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
|
||||
"""Get the node ID of an IP address.
|
||||
|
||||
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
|
||||
|
||||
:param ip: The IP address of the node whose ID is required
|
||||
:type ip: str
|
||||
:param node_dict: The environment's node registry dictionary
|
||||
:type node_dict: Dict[str,NodeUnion]
|
||||
:return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip`
|
||||
:rtype: str
|
||||
"""
|
||||
for node_key, node_value in node_dict.items():
|
||||
node_ip = node_value.ip_address
|
||||
if node_ip == ip:
|
||||
return node_key
|
||||
|
||||
|
||||
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
|
||||
"""
|
||||
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
|
||||
|
||||
Old_action can be either node or acl action type
|
||||
|
||||
:param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
|
||||
:type old_action: np.ndarray
|
||||
:param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
|
||||
:type action_dict: Dict[int,List]
|
||||
:return: Action key correspoinding to the input `old_action`
|
||||
:rtype: int
|
||||
"""
|
||||
for key, val in action_dict.items():
|
||||
if list(val) == list(old_action):
|
||||
return key
|
||||
# Not all possible actions are included in dict, only valid action are
|
||||
# if action is not in the dict, its an invalid action so return 0
|
||||
return 0
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Objects which are shared between many PrimAITE modules."""
|
||||
@@ -1,8 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
@@ -1,208 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""Node type enumeration."""
|
||||
|
||||
CCTV = 1
|
||||
SWITCH = 2
|
||||
COMPUTER = 3
|
||||
LINK = 4
|
||||
MONITOR = 5
|
||||
PRINTER = 6
|
||||
LOP = 7
|
||||
RTU = 8
|
||||
ACTUATOR = 9
|
||||
SERVER = 10
|
||||
|
||||
|
||||
class Priority(Enum):
|
||||
"""Node priority enumeration."""
|
||||
|
||||
P1 = 1
|
||||
P2 = 2
|
||||
P3 = 3
|
||||
P4 = 4
|
||||
P5 = 5
|
||||
|
||||
|
||||
class HardwareState(Enum):
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
SHUTTING_DOWN = 4
|
||||
BOOTING = 5
|
||||
|
||||
|
||||
class SoftwareState(Enum):
|
||||
"""Software or Service state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
OVERWHELMED = 4
|
||||
|
||||
|
||||
class NodePOLType(Enum):
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
NONE = 0
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
FILE = 4
|
||||
|
||||
|
||||
class NodePOLInitiator(Enum):
|
||||
"""Node Pattern of Life initiator enumeration."""
|
||||
|
||||
DIRECT = 1
|
||||
IER = 2
|
||||
SERVICE = 3
|
||||
|
||||
|
||||
class Protocol(Enum):
|
||||
"""Service protocol enumeration."""
|
||||
|
||||
LDAP = 0
|
||||
FTP = 1
|
||||
HTTPS = 2
|
||||
SMTP = 3
|
||||
RTP = 4
|
||||
IPP = 5
|
||||
TCP = 6
|
||||
NONE = 7
|
||||
|
||||
|
||||
class SessionType(Enum):
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
|
||||
TRAIN = 1
|
||||
"Train an agent"
|
||||
EVAL = 2
|
||||
"Evaluate an agent"
|
||||
TRAIN_EVAL = 3
|
||||
"Train then evaluate an agent"
|
||||
|
||||
|
||||
class AgentFramework(Enum):
|
||||
"""The agent algorithm framework/package."""
|
||||
|
||||
CUSTOM = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
# RLLIB = 2
|
||||
# "Ray RLlib"
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework."""
|
||||
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
"Tensorflow 2.x"
|
||||
TORCH = "torch"
|
||||
"PyTorch"
|
||||
|
||||
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"The Hardcoded agents"
|
||||
DO_NOTHING = 4
|
||||
"The DoNothing agents"
|
||||
RANDOM = 5
|
||||
"The RandomAgent"
|
||||
DUMMY = 6
|
||||
"The DummyAgent"
|
||||
|
||||
|
||||
class HardCodedAgentView(Enum):
|
||||
"""The view the deterministic hard-coded agent has of the environment."""
|
||||
|
||||
BASIC = 1
|
||||
"The current observation space only"
|
||||
FULL = 2
|
||||
"Full environment view with actions taken and reward feedback"
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
|
||||
class FileSystemState(Enum):
|
||||
"""File System State."""
|
||||
|
||||
GOOD = 1
|
||||
CORRUPT = 2
|
||||
DESTROYED = 3
|
||||
REPAIRING = 4
|
||||
RESTORING = 5
|
||||
|
||||
|
||||
class NodeHardwareAction(Enum):
|
||||
"""Node hardware action."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESET = 3
|
||||
|
||||
|
||||
class NodeSoftwareAction(Enum):
|
||||
"""Node software action."""
|
||||
|
||||
NONE = 0
|
||||
PATCHING = 1
|
||||
|
||||
|
||||
class LinkStatus(Enum):
|
||||
"""Link traffic status."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
|
||||
class SB3OutputVerboseLevel(IntEnum):
|
||||
"""The Stable Baselines3 learn/eval output verbosity level."""
|
||||
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
@@ -1,47 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
@@ -1,28 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SoftwareState
|
||||
|
||||
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
:param name: The service name.
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self.software_state = SoftwareState.GOOD
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Configuration parameters for running experiments."""
|
||||
@@ -1,166 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.5
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.6
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH3
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.7
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '8'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '2'
|
||||
- item_type: LINK
|
||||
id: '9'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '3'
|
||||
- item_type: GREEN_IER
|
||||
id: '13'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '3'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 60
|
||||
end_step: 100
|
||||
load: 1000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 80
|
||||
end_step: 80
|
||||
targetNodeId: '2'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: ACL_RULE
|
||||
id: '17'
|
||||
permission: ALLOW
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
@@ -1,366 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC3
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.13
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: PC4
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: IDS
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: LOP1
|
||||
node_class: SERVICE
|
||||
node_type: LOP
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '10'
|
||||
name: SERVER2
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.15
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: link7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: link8
|
||||
bandwidth: 1000000000
|
||||
source: '8'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '19'
|
||||
name: link9
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '9'
|
||||
- item_type: LINK
|
||||
id: '20'
|
||||
name: link10
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '10'
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '3'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '10'
|
||||
mission_criticality: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '25'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.14
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '35'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '2'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '36'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_IER
|
||||
id: '37'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '38'
|
||||
start_step: 30
|
||||
end_step: 30
|
||||
targetNodeId: '9'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -1,164 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '6'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: GREEN_IER
|
||||
id: '8'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '9'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '10'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '11'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.2
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.3
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.4
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 30
|
||||
end_step: 256
|
||||
load: 10000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 40
|
||||
end_step: 40
|
||||
targetNodeId: '4'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -1,546 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: CLIENT_1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: CLIENT_2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH_1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SECURITY_SUITE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH_2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: WEB_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: DATABASE_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: BACKUP_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.16
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: LINK_2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: LINK_3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: LINK_4
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: LINK_5
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: LINK_6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: LINK_7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: LINK_8
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: LINK_9
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '9'
|
||||
- item_type: GREEN_IER
|
||||
id: '19'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '20'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '1'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '7'
|
||||
destination: '8'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 100000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '8'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '25'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '26'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '27'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '7'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '28'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '29'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '8'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '30'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '8'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '31'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '9'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '32'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.16
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '51'
|
||||
start_step: 75
|
||||
end_step: 105
|
||||
load: 10000
|
||||
protocol: UDP
|
||||
port: '53'
|
||||
source: '1'
|
||||
destination: '8'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '52'
|
||||
start_step: 100
|
||||
end_step: 100
|
||||
targetNodeId: '8'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '53'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: CORRUPT
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '54'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP_SQL
|
||||
state: COMPROMISED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '55'
|
||||
start_step: 125
|
||||
end_step: 125
|
||||
targetNodeId: '7'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -1,168 +0,0 @@
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -1,141 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a legacy lay down config to the new format.
|
||||
|
||||
:param legacy_config: A legacy lay down config.
|
||||
"""
|
||||
field_conversion_map = {
|
||||
"itemType": "item_type",
|
||||
"portsList": "ports_list",
|
||||
"serviceList": "service_list",
|
||||
"baseType": "node_class",
|
||||
"nodeType": "node_type",
|
||||
"hardwareState": "hardware_state",
|
||||
"softwareState": "software_state",
|
||||
"startStep": "start_step",
|
||||
"endStep": "end_step",
|
||||
"fileSystemState": "file_system_state",
|
||||
"ipAddress": "ip_address",
|
||||
"missionCriticality": "mission_criticality",
|
||||
}
|
||||
new_config = []
|
||||
for item in legacy_config:
|
||||
if "itemType" in item:
|
||||
if item["itemType"] in ["ACTIONS", "STEPS"]:
|
||||
continue
|
||||
new_dict = {}
|
||||
for key in item.keys():
|
||||
conversion_key = field_conversion_map.get(key)
|
||||
if key == "id" and "itemType" in item:
|
||||
if item["itemType"] == "NODE":
|
||||
conversion_key = "node_id"
|
||||
if conversion_key:
|
||||
new_dict[conversion_key] = item[key]
|
||||
else:
|
||||
new_dict[key] = item[key]
|
||||
new_config.append(new_dict)
|
||||
return new_config
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||
:return: The lay down config as a dict.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading lay down config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_lay_down_config(config)
|
||||
except KeyError:
|
||||
msg = (
|
||||
f"Failed to convert lay down config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
return config
|
||||
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
@@ -1,438 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The AgentFramework"
|
||||
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework"
|
||||
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
|
||||
"The AgentIdentifier"
|
||||
|
||||
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
|
||||
"The view the deterministic hard-coded agent has of the environment"
|
||||
|
||||
random_red_agent: bool = False
|
||||
"Creates Random Red Agent Attacks"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use"
|
||||
|
||||
num_train_episodes: int = 10
|
||||
"The number of episodes to train over during an training session"
|
||||
|
||||
num_train_steps: int = 256
|
||||
"The number of steps in an episode during an training session"
|
||||
|
||||
num_eval_episodes: int = 1
|
||||
"The number of episodes to train over during an evaluation session"
|
||||
|
||||
num_eval_steps: int = 256
|
||||
"The number of steps in an episode during an evaluation session"
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
|
||||
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
|
||||
"The observation space config dict"
|
||||
|
||||
time_delay: int = 10
|
||||
"The delay between steps (ms). Applies to generic agents only"
|
||||
|
||||
# file
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
"File path and file name of agent if you're loading one in"
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int = 1000000000
|
||||
"The high value for the observation space"
|
||||
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
|
||||
# Node Hardware State
|
||||
off_should_be_on: float = -0.001
|
||||
off_should_be_resetting: float = -0.0005
|
||||
on_should_be_off: float = -0.0002
|
||||
on_should_be_resetting: float = -0.0005
|
||||
resetting_should_be_on: float = -0.0005
|
||||
resetting_should_be_off: float = -0.0002
|
||||
resetting: float = -0.0003
|
||||
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: float = 0.0002
|
||||
good_should_be_compromised: float = 0.0005
|
||||
good_should_be_overwhelmed: float = 0.0005
|
||||
patching_should_be_good: float = -0.0005
|
||||
patching_should_be_compromised: float = 0.0002
|
||||
patching_should_be_overwhelmed: float = 0.0002
|
||||
patching: float = -0.0003
|
||||
compromised_should_be_good: float = -0.002
|
||||
compromised_should_be_patching: float = -0.002
|
||||
compromised_should_be_overwhelmed: float = -0.002
|
||||
compromised: float = -0.002
|
||||
overwhelmed_should_be_good: float = -0.002
|
||||
overwhelmed_should_be_patching: float = -0.002
|
||||
overwhelmed_should_be_compromised: float = -0.002
|
||||
overwhelmed: float = -0.002
|
||||
|
||||
# Node File System State
|
||||
good_should_be_repairing: float = 0.0002
|
||||
good_should_be_restoring: float = 0.0002
|
||||
good_should_be_corrupt: float = 0.0005
|
||||
good_should_be_destroyed: float = 0.001
|
||||
repairing_should_be_good: float = -0.0005
|
||||
repairing_should_be_restoring: float = 0.0002
|
||||
repairing_should_be_corrupt: float = 0.0002
|
||||
repairing_should_be_destroyed: float = 0.0000
|
||||
repairing: float = -0.0003
|
||||
restoring_should_be_good: float = -0.001
|
||||
restoring_should_be_repairing: float = -0.0002
|
||||
restoring_should_be_corrupt: float = 0.0001
|
||||
restoring_should_be_destroyed: float = 0.0002
|
||||
restoring: float = -0.0006
|
||||
corrupt_should_be_good: float = -0.001
|
||||
corrupt_should_be_repairing: float = -0.001
|
||||
corrupt_should_be_restoring: float = -0.001
|
||||
corrupt_should_be_destroyed: float = 0.0002
|
||||
corrupt: float = -0.001
|
||||
destroyed_should_be_good: float = -0.002
|
||||
destroyed_should_be_repairing: float = -0.002
|
||||
destroyed_should_be_restoring: float = -0.002
|
||||
destroyed_should_be_corrupt: float = -0.002
|
||||
destroyed: float = -0.002
|
||||
scanning: float = -0.0002
|
||||
|
||||
# IER status
|
||||
red_ier_running: float = -0.0005
|
||||
green_ier_blocked: float = -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
"The time taken to patch the OS"
|
||||
|
||||
node_reset_duration: int = 5
|
||||
"The time taken to reset a node (hardware)"
|
||||
|
||||
node_booting_duration: int = 3
|
||||
"The Time taken to turn on the node"
|
||||
|
||||
node_shutdown_duration: int = 2
|
||||
"The time taken to turn off the node"
|
||||
|
||||
service_patching_duration: int = 5
|
||||
"The time taken to patch a service"
|
||||
|
||||
file_system_repairing_limit: int = 5
|
||||
"The time take to repair the file system"
|
||||
|
||||
file_system_restoring_limit: int = 5
|
||||
"The time take to restore the file system"
|
||||
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
deterministic: bool = False
|
||||
"If true, the training will be deterministic"
|
||||
|
||||
seed: Optional[int] = None
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
"""
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"agent_identifier": AgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
:param json_serializable: If True, Enums are converted to their
|
||||
string name.
|
||||
:return: The ``TrainingConfig`` as a dict.
|
||||
"""
|
||||
data = self.__dict__
|
||||
if json_serializable:
|
||||
data["agent_framework"] = self.agent_framework.name
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.name
|
||||
data["agent_identifier"] = self.agent_identifier.name
|
||||
data["action_type"] = self.action_type.name
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
def __str__(self) -> str:
|
||||
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
|
||||
tc = f"{self.agent_framework}, "
|
||||
# if self.agent_framework is AgentFramework.RLLIB:
|
||||
# tc += f"{self.deep_learning_framework}, "
|
||||
tc += f"{self.agent_identifier}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
tc += f"{self.action_type}, "
|
||||
tc += f"observation_space={obs_str}, "
|
||||
if self.session_type is SessionType.TRAIN:
|
||||
tc += f"{self.num_train_episodes} episodes @ "
|
||||
tc += f"{self.num_train_steps} steps"
|
||||
elif self.session_type is SessionType.EVAL:
|
||||
tc += f"{self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
else:
|
||||
tc += f"Training: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
return tc
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
False.
|
||||
:return: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading training config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_training_config_dict(config)
|
||||
|
||||
except KeyError as e:
|
||||
msg = (
|
||||
f"Failed to convert training config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise e
|
||||
try:
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
|
||||
_LOGGER.critical(msg, exc_info=True)
|
||||
raise e
|
||||
msg = f"Cannot load the training config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_train_steps: int = 256,
|
||||
num_eval_steps: int = 256,
|
||||
num_train_episodes: int = 10,
|
||||
num_eval_episodes: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_train_steps: The number of train steps to set as legacy training configs
|
||||
don't have num_train_steps values.
|
||||
:param num_eval_steps: The number of eval steps to set as legacy training configs
|
||||
don't have num_eval_steps values.
|
||||
:param num_train_episodes: The number of train episodes to set as legacy training configs
|
||||
don't have num_train_episodes values.
|
||||
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
|
||||
don't have num_eval_episodes values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"agent_framework": agent_framework.name,
|
||||
"agent_identifier": agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_train_steps": num_train_steps,
|
||||
"num_eval_steps": num_eval_steps,
|
||||
"num_train_episodes": num_train_episodes,
|
||||
"num_eval_episodes": num_eval_episodes,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
|
||||
}
|
||||
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
|
||||
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
config_dict[new_key] = value
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
:param legacy_key: A legacy training config key.
|
||||
:return: The mapped key.
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": None,
|
||||
"numEpisodes": "num_train_episodes",
|
||||
"numSteps": "num_train_steps",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
"sessionType": "session_type",
|
||||
"loadAgent": "load_agent",
|
||||
"agentLoadFile": "agent_load_file",
|
||||
"observationSpaceHighValue": "observation_space_high_value",
|
||||
"allOk": "all_ok",
|
||||
"offShouldBeOn": "off_should_be_on",
|
||||
"offShouldBeResetting": "off_should_be_resetting",
|
||||
"onShouldBeOff": "on_should_be_off",
|
||||
"onShouldBeResetting": "on_should_be_resetting",
|
||||
"resettingShouldBeOn": "resetting_should_be_on",
|
||||
"resettingShouldBeOff": "resetting_should_be_off",
|
||||
"resetting": "resetting",
|
||||
"goodShouldBePatching": "good_should_be_patching",
|
||||
"goodShouldBeCompromised": "good_should_be_compromised",
|
||||
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
|
||||
"patchingShouldBeGood": "patching_should_be_good",
|
||||
"patchingShouldBeCompromised": "patching_should_be_compromised",
|
||||
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
|
||||
"patching": "patching",
|
||||
"compromisedShouldBeGood": "compromised_should_be_good",
|
||||
"compromisedShouldBePatching": "compromised_should_be_patching",
|
||||
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
|
||||
"compromised": "compromised",
|
||||
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
|
||||
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
|
||||
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
|
||||
"overwhelmed": "overwhelmed",
|
||||
"goodShouldBeRepairing": "good_should_be_repairing",
|
||||
"goodShouldBeRestoring": "good_should_be_restoring",
|
||||
"goodShouldBeCorrupt": "good_should_be_corrupt",
|
||||
"goodShouldBeDestroyed": "good_should_be_destroyed",
|
||||
"repairingShouldBeGood": "repairing_should_be_good",
|
||||
"repairingShouldBeRestoring": "repairing_should_be_restoring",
|
||||
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
|
||||
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
|
||||
"repairing": "repairing",
|
||||
"restoringShouldBeGood": "restoring_should_be_good",
|
||||
"restoringShouldBeRepairing": "restoring_should_be_repairing",
|
||||
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
|
||||
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
|
||||
"restoring": "restoring",
|
||||
"corruptShouldBeGood": "corrupt_should_be_good",
|
||||
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
|
||||
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
|
||||
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
|
||||
"corrupt": "corrupt",
|
||||
"destroyedShouldBeGood": "destroyed_should_be_good",
|
||||
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
|
||||
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
|
||||
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
|
||||
"destroyed": "destroyed",
|
||||
"scanning": "scanning",
|
||||
"redIerRunning": "red_ier_running",
|
||||
"greenIerBlocked": "green_ier_blocked",
|
||||
"osPatchingDuration": "os_patching_duration",
|
||||
"nodeResetDuration": "node_reset_duration",
|
||||
"nodeBootingDuration": "node_booting_duration",
|
||||
"nodeShutdownDuration": "node_shutdown_duration",
|
||||
"servicePatchingDuration": "service_patching_duration",
|
||||
"fileSystemRepairingLimit": "file_system_repairing_limit",
|
||||
"fileSystemRestoringLimit": "file_system_restoring_limit",
|
||||
"fileSystemScanningLimit": "file_system_scanning_limit",
|
||||
}
|
||||
return key_mapping[legacy_key]
|
||||
@@ -1,15 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Utility to generate plots of sessions metrics after PrimAITE."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PlotlyTemplate(Enum):
|
||||
"""The built-in plotly templates."""
|
||||
|
||||
PLOTLY = "plotly"
|
||||
PLOTLY_WHITE = "plotly_white"
|
||||
PLOTLY_DARK = "plotly_dark"
|
||||
GGPLOT2 = "ggplot2"
|
||||
SEABORN = "seaborn"
|
||||
SIMPLE_WHITE = "simple_white"
|
||||
NONE = "none"
|
||||
@@ -1,73 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
def get_plotly_config() -> Dict:
|
||||
"""Get the plotly config from primaite_config.yaml."""
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config["session"]["outputs"]["plots"]
|
||||
|
||||
|
||||
def plot_av_reward_per_episode(
|
||||
av_reward_per_episode_csv: Union[str, Path],
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""
|
||||
Plot the average reward per episode from a csv session output.
|
||||
|
||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||
file path.
|
||||
:param title: The plot title. This is optional.
|
||||
:param subtitle: The plot subtitle. This is optional.
|
||||
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
|
||||
"""
|
||||
df = pl.read_csv(av_reward_per_episode_csv)
|
||||
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["Episode"],
|
||||
y=df["Average Reward"],
|
||||
mode="lines",
|
||||
name="Mean Reward per Episode",
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
"rangeslider": {"visible": config["range_slider"]},
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
|
||||
@@ -1,735 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
# This dependency is only needed for type hints,
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
:param env: Primaite training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
_LOGGER.info(f"Initialising {self} observation component")
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: List[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
|
||||
"""
|
||||
Table with nodes and links as rows and hardware/software status as cols.
|
||||
|
||||
This will create the observation space formatted as a table of integers.
|
||||
There is one row per node, followed by one row per link.
|
||||
The number of columns is 4 plus one per service. They are:
|
||||
|
||||
* node/link ID
|
||||
* node hardware status / 0 for links
|
||||
* node operating system status (if active/service) / 0 for links
|
||||
* node file system status (active/service only) / 0 for links
|
||||
* node service1 status / traffic load from that service for links
|
||||
* node service2 status / traffic load from that service for links
|
||||
* ...
|
||||
* node serviceN status / traffic load from that service for links
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(12, 7)``
|
||||
"""
|
||||
|
||||
_FIXED_PARAMETERS: int = 4
|
||||
_MAX_VAL: int = 1_000_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeLinkTable observation space component.
|
||||
|
||||
:param env: Training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
num_items = self.env.num_links + self.env.num_nodes
|
||||
num_columns = self.env.num_services + self._FIXED_PARAMETERS
|
||||
observation_shape = (num_items, num_columns)
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.Box(
|
||||
low=0,
|
||||
high=self._MAX_VAL,
|
||||
shape=observation_shape,
|
||||
dtype=self._DATA_TYPE,
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||
"""
|
||||
item_index = 0
|
||||
nodes = self.env.nodes
|
||||
links = self.env.links
|
||||
# Do nodes first
|
||||
for _, node in nodes.items():
|
||||
self.current_observation[item_index][0] = int(node.node_id)
|
||||
self.current_observation[item_index][1] = node.hardware_state.value
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
self.current_observation[item_index][2] = node.software_state.value
|
||||
self.current_observation[item_index][3] = node.file_system_state_observed.value
|
||||
else:
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
service_index = 4
|
||||
if isinstance(node, ServiceNode):
|
||||
for service in self.env.services_list:
|
||||
if node.has_service(service):
|
||||
self.current_observation[item_index][service_index] = node.get_service_state(service).value
|
||||
else:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
else:
|
||||
# Not a service node
|
||||
for service in self.env.services_list:
|
||||
self.current_observation[item_index][service_index] = 0
|
||||
service_index += 1
|
||||
item_index += 1
|
||||
|
||||
# Now do links
|
||||
for _, link in links.items():
|
||||
self.current_observation[item_index][0] = int(link.get_id())
|
||||
self.current_observation[item_index][1] = 0
|
||||
self.current_observation[item_index][2] = 0
|
||||
self.current_observation[item_index][3] = 0
|
||||
protocol_list = link.get_protocol_list()
|
||||
protocol_index = 0
|
||||
for protocol in protocol_list:
|
||||
self.current_observation[item_index][protocol_index + 4] = protocol.get_load()
|
||||
protocol_index += 1
|
||||
item_index += 1
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
nodes = self.env.nodes.values()
|
||||
links = self.env.links.values()
|
||||
|
||||
structure = []
|
||||
|
||||
for i, node in enumerate(nodes):
|
||||
node_id = node.node_id
|
||||
node_labels = [
|
||||
f"node_{node_id}_id",
|
||||
f"node_{node_id}_hardware_status",
|
||||
f"node_{node_id}_os_status",
|
||||
f"node_{node_id}_fs_status",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
node_labels.append(f"node_{node_id}_service_{serv}_status")
|
||||
|
||||
structure.extend(node_labels)
|
||||
|
||||
for i, link in enumerate(links):
|
||||
link_id = link.id
|
||||
link_labels = [
|
||||
f"link_{link_id}_id",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
f"link_{link_id}_n/a",
|
||||
]
|
||||
for j, serv in enumerate(self.env.services_list):
|
||||
link_labels.append(f"link_{link_id}_service_{serv}_load")
|
||||
|
||||
structure.extend(link_labels)
|
||||
return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""
|
||||
Flat list of nodes' hardware, OS, file system, and service states.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
Each node has 3 elements plus 1 per service. It will have the following structure:
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
node1 hardware state,
|
||||
node1 OS state,
|
||||
node1 file system state,
|
||||
node1 service1 state,
|
||||
node1 service2 state,
|
||||
node1 serviceN state (one for each service),
|
||||
node2 hardware state,
|
||||
node2 OS state,
|
||||
node2 file system state,
|
||||
node2 service1 state,
|
||||
node2 service2 state,
|
||||
node2 serviceN state (one for each service),
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise a NodeStatuses observation component.
|
||||
|
||||
:param env: Training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
node_shape = [
|
||||
len(HardwareState) + 1,
|
||||
len(SoftwareState) + 1,
|
||||
len(FileSystemState) + 1,
|
||||
]
|
||||
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
||||
node_shape = node_shape + services_shape
|
||||
|
||||
shape = node_shape * self.env.num_nodes
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||
"""
|
||||
obs = []
|
||||
for _, node in self.env.nodes.items():
|
||||
hardware_state = node.hardware_state.value
|
||||
software_state = 0
|
||||
file_system_state = 0
|
||||
service_states = [0] * self.env.num_services
|
||||
|
||||
if isinstance(node, ActiveNode):
|
||||
software_state = node.software_state.value
|
||||
file_system_state = node.file_system_state_observed.value
|
||||
|
||||
if isinstance(node, ServiceNode):
|
||||
for i, service in enumerate(self.env.services_list):
|
||||
if node.has_service(service):
|
||||
service_states[i] = node.get_service_state(service).value
|
||||
obs.extend(
|
||||
[
|
||||
hardware_state,
|
||||
software_state,
|
||||
file_system_state,
|
||||
*service_states,
|
||||
]
|
||||
)
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
services = self.env.services_list
|
||||
|
||||
structure = []
|
||||
|
||||
for _, node in self.env.nodes.items():
|
||||
node_id = node.node_id
|
||||
structure.append(f"node_{node_id}_hardware_state_NONE")
|
||||
for state in HardwareState:
|
||||
structure.append(f"node_{node_id}_hardware_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_software_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(f"node_{node_id}_software_state_{state.name}")
|
||||
structure.append(f"node_{node_id}_file_system_state_NONE")
|
||||
for state in FileSystemState:
|
||||
structure.append(f"node_{node_id}_file_system_state_{state.name}")
|
||||
for service in services:
|
||||
structure.append(f"node_{node_id}_service_{service}_state_NONE")
|
||||
for state in SoftwareState:
|
||||
structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
|
||||
return structure
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""
|
||||
Flat list of traffic levels encoded into banded categories.
|
||||
|
||||
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||
|
||||
* 0 = No traffic (0% of bandwidth)
|
||||
* 1 = No traffic (0%-33% of bandwidth)
|
||||
* 2 = No traffic (33%-66% of bandwidth)
|
||||
* 3 = No traffic (66%-100% of bandwidth)
|
||||
* 4 = No traffic (100% of bandwidth)
|
||||
|
||||
.. note::
|
||||
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
|
||||
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
|
||||
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a LinkTrafficLevels observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
|
||||
defaults to False
|
||||
:type combine_service_traffic: bool, optional
|
||||
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
|
||||
value, defaults to 5
|
||||
:type quantisation_levels: int, optional
|
||||
"""
|
||||
if quantisation_levels < 3:
|
||||
_msg = (
|
||||
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
||||
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
||||
f"Resetting to default value (5)"
|
||||
)
|
||||
_LOGGER.warning(_msg)
|
||||
quantisation_levels = 5
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
self._combine_service_traffic: bool = combine_service_traffic
|
||||
self._quantisation_levels: int = quantisation_levels
|
||||
self._entries_per_link: int = 1
|
||||
|
||||
if not self._combine_service_traffic:
|
||||
self._entries_per_link = self.env.num_services
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""
|
||||
Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||
"""
|
||||
obs = []
|
||||
for _, link in self.env.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
if self._combine_service_traffic:
|
||||
loads = [link.get_current_load()]
|
||||
else:
|
||||
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||
|
||||
for load in loads:
|
||||
if load <= 0:
|
||||
traffic_level = 0
|
||||
elif load >= bandwidth:
|
||||
traffic_level = self._quantisation_levels - 1
|
||||
else:
|
||||
traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1
|
||||
|
||||
obs.append(int(traffic_level))
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for _, link in self.env.links.items():
|
||||
link_id = link.id
|
||||
if self._combine_service_traffic:
|
||||
protocols = ["overall"]
|
||||
else:
|
||||
protocols = [protocol.name for protocol in link.protocol_list]
|
||||
|
||||
for p in protocols:
|
||||
for i in range(self._quantisation_levels):
|
||||
structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
|
||||
return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
|
||||
"""Flat list of all the Access Control Rules in the Access Control List.
|
||||
|
||||
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
||||
integers.
|
||||
|
||||
Each ACL Rule has 6 elements. It will have the following structure:
|
||||
.. code-block::
|
||||
[
|
||||
acl_rule1 permission,
|
||||
acl_rule1 source_ip,
|
||||
acl_rule1 dest_ip,
|
||||
acl_rule1 protocol,
|
||||
acl_rule1 port,
|
||||
acl_rule1 position,
|
||||
acl_rule2 permission,
|
||||
acl_rule2 source_ip,
|
||||
acl_rule2 dest_ip,
|
||||
acl_rule2 protocol,
|
||||
acl_rule2 port,
|
||||
acl_rule2 position,
|
||||
...
|
||||
]
|
||||
|
||||
|
||||
Terms (for ACL Observation Space):
|
||||
[0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
|
||||
[0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
|
||||
[0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
|
||||
[0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
|
||||
[0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)
|
||||
|
||||
NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
|
||||
"""
|
||||
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise an AccessControlList observation component.
|
||||
|
||||
:param env: The environment that forms the basis of the observations
|
||||
:type env: Primaite
|
||||
"""
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
# The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
|
||||
# Number of ACL rules incremented by 1 for positions starting at index 0.
|
||||
acl_shape = [
|
||||
len(RulePermissionType),
|
||||
len(env.nodes) + 2,
|
||||
len(env.nodes) + 2,
|
||||
len(env.services_list) + 2,
|
||||
len(env.ports_list) + 2,
|
||||
env.max_number_acl_rules,
|
||||
]
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
self.structure = self.generate_structure()
|
||||
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
The structure of the observation space is described in :class:`.AccessControlList`
|
||||
"""
|
||||
obs = []
|
||||
|
||||
for index in range(0, len(self.env.acl.acl)):
|
||||
acl_rule = self.env.acl.acl[index]
|
||||
if isinstance(acl_rule, ACLRule):
|
||||
permission = acl_rule.permission
|
||||
source_ip = acl_rule.source_ip
|
||||
dest_ip = acl_rule.dest_ip
|
||||
protocol = acl_rule.protocol
|
||||
port = acl_rule.port
|
||||
position = index
|
||||
# Map each ACL attribute from what it was to an integer to fit the observation space
|
||||
source_ip_int = None
|
||||
dest_ip_int = None
|
||||
if permission == RulePermissionType.DENY:
|
||||
permission_int = 1
|
||||
else:
|
||||
permission_int = 2
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to source IP address
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == source_ip:
|
||||
source_ip_int = int(node.node_id) + 1
|
||||
break
|
||||
if dest_ip == "ANY":
|
||||
dest_ip_int = 1
|
||||
else:
|
||||
# Map Node ID (+ 1) to dest IP address
|
||||
# Index of Nodes start at 1 so + 1 is needed so NA can be added.
|
||||
nodes = list(self.env.nodes.values())
|
||||
for node in nodes:
|
||||
if (
|
||||
isinstance(node, ServiceNode) or isinstance(node, ActiveNode)
|
||||
) and node.ip_address == dest_ip:
|
||||
dest_ip_int = int(node.node_id) + 1
|
||||
if protocol == "ANY":
|
||||
protocol_int = 1
|
||||
else:
|
||||
# Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
|
||||
try:
|
||||
protocol_int = self.env.services_list.index(protocol) + 2
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = None
|
||||
if port == "ANY":
|
||||
port_int = 1
|
||||
else:
|
||||
if port in self.env.ports_list:
|
||||
port_int = self.env.ports_list.index(port) + 2
|
||||
else:
|
||||
_LOGGER.info(f"Port {port} could not be found.")
|
||||
port_int = None
|
||||
# Add to current obs
|
||||
obs.extend(
|
||||
[
|
||||
permission_int,
|
||||
source_ip_int,
|
||||
dest_ip_int,
|
||||
protocol_int,
|
||||
port_int,
|
||||
position,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
# The Nothing or NA representation of 'NONE' ACL rules
|
||||
obs.extend([0, 0, 0, 0, 0, 0])
|
||||
|
||||
self.current_observation[:] = obs
|
||||
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
structure = []
|
||||
for acl_rule in self.env.acl.acl:
|
||||
acl_rule_id = self.env.acl.acl.index(acl_rule)
|
||||
|
||||
for permission in RulePermissionType:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
|
||||
for node in self.env.nodes.keys():
|
||||
structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
|
||||
for service in self.env.services_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_service_{service}")
|
||||
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
|
||||
for port in self.env.ports_list:
|
||||
structure.append(f"acl_rule_{acl_rule_id}_port_{port}")
|
||||
|
||||
return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
|
||||
"""
|
||||
Component-based observation space handler.
|
||||
|
||||
This allows users to configure observation spaces by mixing and matching components. Each component can also define
|
||||
further parameters to make them more flexible.
|
||||
"""
|
||||
|
||||
_REGISTRY: Final[Dict[str, type]] = {
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
"ACCESS_CONTROL_LIST": AccessControlList,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the observation handler."""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
|
||||
# internal the observation space (unflattened version of space if flatten=True)
|
||||
self._space: spaces.Space
|
||||
# flattened version of the observation space
|
||||
self._flat_space: spaces.Space
|
||||
|
||||
self._observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||
# used for transactions and when flatten=true
|
||||
self._flat_observation: np.ndarray
|
||||
|
||||
def update_obs(self) -> None:
|
||||
"""Fetch fresh information about the environment."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
|
||||
if len(current_obs) == 1:
|
||||
self._observation = current_obs[0]
|
||||
else:
|
||||
self._observation = tuple(current_obs)
|
||||
self._flat_observation = spaces.flatten(self._space, self._observation)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Add a component for this handler to track.
|
||||
|
||||
:param obs_component: The component to add.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent) -> None:
|
||||
"""
|
||||
Remove a component from this handler.
|
||||
|
||||
:param obs_component: Which component to remove. It must exist within this object's
|
||||
``registered_obs_components`` attribute.
|
||||
:type obs_component: AbstractObservationComponent
|
||||
"""
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self) -> None:
|
||||
"""Rebuild the handler's composite observation space from its components."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
|
||||
# if there are multiple components, build a composite tuple space
|
||||
if len(component_spaces) == 1:
|
||||
self._space = component_spaces[0]
|
||||
else:
|
||||
self._space = spaces.Tuple(component_spaces)
|
||||
if len(component_spaces) > 0:
|
||||
self._flat_space = spaces.flatten_space(self._space)
|
||||
else:
|
||||
self._flat_space = spaces.Box(0, 1, (0,))
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Observation space, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_space
|
||||
else:
|
||||
return self._space
|
||||
|
||||
@property
|
||||
def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
|
||||
"""Current observation, return the flattened version if flatten is True."""
|
||||
if len(self.registered_obs_components) > 1:
|
||||
return self._flat_observation
|
||||
else:
|
||||
return self._observation
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
|
||||
"""
|
||||
Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
||||
|
||||
The expected format for the config dictionary is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
config = {
|
||||
components: [
|
||||
{
|
||||
"name": "<COMPONENT1_NAME>"
|
||||
},
|
||||
{
|
||||
"name": "<COMPONENT2_NAME>"
|
||||
"options": {"opt1": val1, "opt2": val2}
|
||||
},
|
||||
{
|
||||
...
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
:return: Observation handler
|
||||
:rtype: primaite.environment.observations.ObservationsHandler
|
||||
"""
|
||||
# Instantiate the handler
|
||||
handler = cls()
|
||||
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
comp_class = cls._REGISTRY[comp_type]
|
||||
|
||||
# Create the component with options from the YAML
|
||||
options = component_cfg.get("options") or {}
|
||||
component = comp_class(env, **options)
|
||||
|
||||
handler.register(component)
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
def describe_structure(self) -> List[str]:
|
||||
"""
|
||||
Create a list of names for the features of the obs space.
|
||||
|
||||
The order of labels follows the flattened version of the space.
|
||||
"""
|
||||
# as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
|
||||
# to fake it. each component has to just hard-code the expected label order after flattening...
|
||||
|
||||
labels = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
labels.extend(obs_comp.structure)
|
||||
|
||||
return labels
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,386 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements reward function."""
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes: Dict[str, NodeUnion],
|
||||
final_nodes: Dict[str, NodeUnion],
|
||||
reference_nodes: Dict[str, NodeUnion],
|
||||
green_iers: Dict[str, "IER"],
|
||||
green_iers_reference: Dict[str, "IER"],
|
||||
red_iers: Dict[str, "IER"],
|
||||
step_count: int,
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Compares the states of the initial and final nodes/links to get a reward.
|
||||
|
||||
Args:
|
||||
initial_nodes: The nodes before red and blue agents take effect
|
||||
final_nodes: The nodes after red and blue agents take effect
|
||||
reference_nodes: The nodes if there had been no red or blue effect
|
||||
green_iers: The green IERs (should be running)
|
||||
red_iers: Should be stopeed (ideally) by the blue agent
|
||||
step_count: current step
|
||||
config_values: Config values
|
||||
"""
|
||||
reward_value: float = 0.0
|
||||
|
||||
# For each node, compare hardware state, SoftwareState, service states
|
||||
for node_key, final_node in final_nodes.items():
|
||||
initial_node = initial_nodes[node_key]
|
||||
reference_node = reference_nodes[node_key]
|
||||
|
||||
# Hardware State
|
||||
reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Software State
|
||||
if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Service State
|
||||
if isinstance(final_node, ServiceNode):
|
||||
reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# File System State
|
||||
if isinstance(final_node, ActiveNode):
|
||||
reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)
|
||||
|
||||
# Go through each red IER - penalise if it is running
|
||||
for ier_key, ier_value in red_iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if ier_value.get_is_running():
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(
|
||||
final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the hardware state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_operating_state = final_node.hardware_state
|
||||
reference_node_operating_state = reference_node.hardware_state
|
||||
|
||||
if final_node_operating_state == reference_node_operating_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final (current) state of node (i.e. at every step)
|
||||
if reference_node_operating_state == HardwareState.ON:
|
||||
if final_node_operating_state == HardwareState.OFF:
|
||||
score += config_values.off_should_be_on
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting_should_be_on
|
||||
else:
|
||||
pass
|
||||
elif reference_node_operating_state == HardwareState.OFF:
|
||||
if final_node_operating_state == HardwareState.ON:
|
||||
score += config_values.on_should_be_off
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting_should_be_off
|
||||
else:
|
||||
pass
|
||||
elif reference_node_operating_state == HardwareState.RESETTING:
|
||||
if final_node_operating_state == HardwareState.ON:
|
||||
score += config_values.on_should_be_resetting
|
||||
elif final_node_operating_state == HardwareState.OFF:
|
||||
score += config_values.off_should_be_resetting
|
||||
elif final_node_operating_state == HardwareState.RESETTING:
|
||||
score += config_values.resetting
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_os_state(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the Software State of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_os_state = final_node.software_state
|
||||
reference_node_os_state = reference_node.software_state
|
||||
|
||||
if final_node_os_state == reference_node_os_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final (current) state of node (i.e. at every step)
|
||||
if reference_node_os_state == SoftwareState.GOOD:
|
||||
if final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_node_os_state == SoftwareState.PATCHING:
|
||||
if final_node_os_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif reference_node_os_state == SoftwareState.COMPROMISED:
|
||||
if final_node_os_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_node_os_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_node_os_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_service_state(
|
||||
final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the service state(s) of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
config_values: Config values
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_services: Dict[str, Service] = final_node.services
|
||||
reference_node_services: Dict[str, Service] = reference_node.services
|
||||
|
||||
for service_key, final_service in final_node_services.items():
|
||||
reference_service = reference_node_services[service_key]
|
||||
final_service = final_node_services[service_key]
|
||||
|
||||
if final_service.software_state == reference_service.software_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final state of node (i.e. at every step)
|
||||
if reference_service.software_state == SoftwareState.GOOD:
|
||||
if final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_good
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_good
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.PATCHING:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_patching
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.COMPROMISED:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_compromised
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_compromised
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed_should_be_compromised
|
||||
else:
|
||||
pass
|
||||
elif reference_service.software_state == SoftwareState.OVERWHELMED:
|
||||
if final_service.software_state == SoftwareState.GOOD:
|
||||
score += config_values.good_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.PATCHING:
|
||||
score += config_values.patching_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.COMPROMISED:
|
||||
score += config_values.compromised_should_be_overwhelmed
|
||||
elif final_service.software_state == SoftwareState.OVERWHELMED:
|
||||
score += config_values.overwhelmed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def score_node_file_system(
|
||||
final_node: Union[ActiveNode, ServiceNode],
|
||||
initial_node: Union[ActiveNode, ServiceNode],
|
||||
reference_node: Union[ActiveNode, ServiceNode],
|
||||
config_values: "TrainingConfig",
|
||||
) -> float:
|
||||
"""
|
||||
Calculates score relating to the file system state of a node.
|
||||
|
||||
Args:
|
||||
final_node: The node after red and blue agents take effect
|
||||
initial_node: The node before red and blue agents take effect
|
||||
reference_node: The node if there had been no red or blue effect
|
||||
"""
|
||||
score: float = 0.0
|
||||
final_node_file_system_state = final_node.file_system_state_actual
|
||||
reference_node_file_system_state = reference_node.file_system_state_actual
|
||||
|
||||
final_node_scanning_state = final_node.file_system_scanning
|
||||
reference_node_scanning_state = reference_node.file_system_scanning
|
||||
|
||||
# File System State
|
||||
if final_node_file_system_state == reference_node_file_system_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# Need to compare reference and final state of node (i.e. at every step)
|
||||
if reference_node_file_system_state == FileSystemState.GOOD:
|
||||
if final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_good
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_good
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.REPAIRING:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_repairing
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.RESTORING:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_restoring
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.CORRUPT:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed_should_be_corrupt
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt
|
||||
else:
|
||||
pass
|
||||
elif reference_node_file_system_state == FileSystemState.DESTROYED:
|
||||
if final_node_file_system_state == FileSystemState.GOOD:
|
||||
score += config_values.good_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.REPAIRING:
|
||||
score += config_values.repairing_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.RESTORING:
|
||||
score += config_values.restoring_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.CORRUPT:
|
||||
score += config_values.corrupt_should_be_destroyed
|
||||
elif final_node_file_system_state == FileSystemState.DESTROYED:
|
||||
score += config_values.destroyed
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
# Scanning State
|
||||
if final_node_scanning_state == reference_node_scanning_state:
|
||||
# All is well - we're no different from the reference situation
|
||||
score += config_values.all_ok
|
||||
else:
|
||||
# We're different from the reference situation
|
||||
# We're scanning the file system which incurs a penalty (as it slows down systems)
|
||||
score += config_values.scanning
|
||||
|
||||
return score
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Network connections between nodes in the simulation."""
|
||||
@@ -1,114 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The link class."""
|
||||
from typing import List
|
||||
|
||||
from primaite.common.protocol import Protocol
|
||||
|
||||
|
||||
class Link(object):
|
||||
"""Link class."""
|
||||
|
||||
def __init__(self, _id: str, _bandwidth: int, _source_node_name: str, _dest_node_name: str, _services: str) -> None:
|
||||
"""
|
||||
Initialise a Link within the simulated network.
|
||||
|
||||
:param _id: The IER id
|
||||
:param _bandwidth: The bandwidth of the link (bps)
|
||||
:param _source_node_name: The name of the source node
|
||||
:param _dest_node_name: The name of the destination node
|
||||
:param _protocols: The protocols to add to the link
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.bandwidth: int = _bandwidth
|
||||
self.source_node_name: str = _source_node_name
|
||||
self.dest_node_name: str = _dest_node_name
|
||||
self.protocol_list: List[Protocol] = []
|
||||
|
||||
# Add the default protocols
|
||||
for protocol_name in _services:
|
||||
self.add_protocol(protocol_name)
|
||||
|
||||
def add_protocol(self, _protocol: str) -> None:
|
||||
"""
|
||||
Adds a new protocol to the list of protocols on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to be added (enum)
|
||||
"""
|
||||
self.protocol_list.append(Protocol(_protocol))
|
||||
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets link ID.
|
||||
|
||||
Returns:
|
||||
Link ID
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_source_node_name(self) -> str:
|
||||
"""
|
||||
Gets source node name.
|
||||
|
||||
Returns:
|
||||
Source node name
|
||||
"""
|
||||
return self.source_node_name
|
||||
|
||||
def get_dest_node_name(self) -> str:
|
||||
"""
|
||||
Gets destination node name.
|
||||
|
||||
Returns:
|
||||
Destination node name
|
||||
"""
|
||||
return self.dest_node_name
|
||||
|
||||
def get_bandwidth(self) -> int:
|
||||
"""
|
||||
Gets bandwidth of link.
|
||||
|
||||
Returns:
|
||||
Link bandwidth (bps)
|
||||
"""
|
||||
return self.bandwidth
|
||||
|
||||
def get_protocol_list(self) -> List[Protocol]:
|
||||
"""
|
||||
Gets list of protocols on this link.
|
||||
|
||||
Returns:
|
||||
List of protocols on this link
|
||||
"""
|
||||
return self.protocol_list
|
||||
|
||||
def get_current_load(self) -> int:
|
||||
"""
|
||||
Gets current total load on this link.
|
||||
|
||||
Returns:
|
||||
Total load on this link (bps)
|
||||
"""
|
||||
total_load = 0
|
||||
for protocol in self.protocol_list:
|
||||
total_load += protocol.get_load()
|
||||
return total_load
|
||||
|
||||
def add_protocol_load(self, _protocol: str, _load: int) -> None:
|
||||
"""
|
||||
Adds a loading to a protocol on this link.
|
||||
|
||||
Args:
|
||||
_protocol: The protocol to load
|
||||
_load: The amount to load (bps)
|
||||
"""
|
||||
for protocol in self.protocol_list:
|
||||
if protocol.get_name() == _protocol:
|
||||
protocol.add_load(_load)
|
||||
else:
|
||||
pass
|
||||
|
||||
def clear_traffic(self) -> None:
|
||||
"""Clears all traffic on this link."""
|
||||
for protocol in self.protocol_list:
|
||||
protocol.clear_load()
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Nodes represent network hosts in the simulation."""
|
||||
@@ -1,208 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActiveNode(Node):
|
||||
"""Active Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an active node.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
:param node_type: The node type (enum)
|
||||
:param priority: The node priority (enum)
|
||||
:param hardware_state: The node Hardware State
|
||||
:param ip_address: The node IP address
|
||||
:param software_state: The node Software State
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
self.ip_address: str = ip_address
|
||||
# Related to Software
|
||||
self._software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
# Related to File System
|
||||
self.file_system_state_actual: FileSystemState = file_system_state
|
||||
self.file_system_state_observed: FileSystemState = file_system_state
|
||||
self.file_system_scanning: bool = False
|
||||
self.file_system_scanning_count: int = 0
|
||||
self.file_system_action_count: int = 0
|
||||
|
||||
@property
|
||||
def software_state(self) -> SoftwareState:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:return: The software_state.
|
||||
"""
|
||||
return self._software_state
|
||||
|
||||
@software_state.setter
|
||||
def software_state(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Get the software_state.
|
||||
|
||||
:param software_state: Software State.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be "
|
||||
f"changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets Software State if the node is not compromised.
|
||||
|
||||
Args:
|
||||
software_state: Software State
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
if self._software_state != SoftwareState.COMPROMISED:
|
||||
self._software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
self.patching_count = self.config_values.os_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so OS State cannot be changed."
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.software_state:{self._software_state}"
|
||||
)
|
||||
|
||||
def update_os_patching_status(self) -> None:
|
||||
"""Updates operating system status based on patching cycle."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self._software_state = SoftwareState.GOOD
|
||||
|
||||
def set_file_system_state(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed).
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so File System State "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
|
||||
"""
|
||||
Sets the file system state (actual and observed) if not in a compromised state.
|
||||
|
||||
Use for green PoL to prevent it overturning a compromised state
|
||||
|
||||
Args:
|
||||
file_system_state: File system state
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
if (
|
||||
self.file_system_state_actual != FileSystemState.CORRUPT
|
||||
and self.file_system_state_actual != FileSystemState.DESTROYED
|
||||
):
|
||||
self.file_system_state_actual = file_system_state
|
||||
|
||||
if file_system_state == FileSystemState.REPAIRING:
|
||||
self.file_system_action_count = self.config_values.file_system_repairing_limit
|
||||
self.file_system_state_observed = FileSystemState.REPAIRING
|
||||
elif file_system_state == FileSystemState.RESTORING:
|
||||
self.file_system_action_count = self.config_values.file_system_restoring_limit
|
||||
self.file_system_state_observed = FileSystemState.RESTORING
|
||||
elif file_system_state == FileSystemState.GOOD:
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so File System State (if not "
|
||||
f"compromised) cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.file_system_state.actual:{self.file_system_state_actual}"
|
||||
)
|
||||
|
||||
def start_file_system_scan(self) -> None:
|
||||
"""Starts a file system scan."""
|
||||
self.file_system_scanning = True
|
||||
self.file_system_scanning_count = self.config_values.file_system_scanning_limit
|
||||
|
||||
def update_file_system_state(self) -> None:
|
||||
"""Updates file system status based on scanning/restore/repair cycle."""
|
||||
# Deprecate both the action count (for restoring or reparing) and the scanning count
|
||||
self.file_system_action_count -= 1
|
||||
self.file_system_scanning_count -= 1
|
||||
|
||||
# Reparing / Restoring updates
|
||||
if self.file_system_action_count <= 0:
|
||||
self.file_system_action_count = 0
|
||||
if (
|
||||
self.file_system_state_actual == FileSystemState.REPAIRING
|
||||
or self.file_system_state_actual == FileSystemState.RESTORING
|
||||
):
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.file_system_state_observed = FileSystemState.GOOD
|
||||
|
||||
# Scanning updates
|
||||
if self.file_system_scanning == True and self.file_system_scanning_count < 0:
|
||||
self.file_system_state_observed = self.file_system_state_actual
|
||||
self.file_system_scanning = False
|
||||
self.file_system_scanning_count = 0
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the reset count & makes software and file state to GOOD."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting software and file state to GOOD."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
self.file_system_state_actual = FileSystemState.GOOD
|
||||
self.software_state = SoftwareState.GOOD
|
||||
@@ -1,79 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
|
||||
class Node:
|
||||
"""Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a node.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
:param node_type: The type of the node.
|
||||
:param priority: The priority of the node.
|
||||
:param hardware_state: The state of the node.
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
self.node_id: Final[str] = node_id
|
||||
self.name: Final[str] = name
|
||||
self.node_type: Final[NodeType] = node_type
|
||||
self.priority = priority
|
||||
self.hardware_state: HardwareState = hardware_state
|
||||
self.resetting_count: int = 0
|
||||
self.config_values: TrainingConfig = config_values
|
||||
self.booting_count: int = 0
|
||||
self.shutting_down_count: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns the name of the node."""
|
||||
return self.name
|
||||
|
||||
def turn_on(self) -> None:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self) -> None:
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
self.shutting_down_count = self.config_values.node_shutdown_duration
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets the node state to Resetting and starts the reset count."""
|
||||
self.hardware_state = HardwareState.RESETTING
|
||||
self.resetting_count = self.config_values.node_reset_duration
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Updates the resetting count."""
|
||||
self.resetting_count -= 1
|
||||
if self.resetting_count <= 0:
|
||||
self.resetting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self) -> None:
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
self.shutting_down_count = 0
|
||||
self.hardware_state = HardwareState.OFF
|
||||
@@ -1,94 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_node_id: str,
|
||||
_node_pol_type: "NodePOLType",
|
||||
_service_name: str,
|
||||
_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction.
|
||||
|
||||
:param _id: The node state instruction id
|
||||
:param _start_step: The start step of the instruction
|
||||
:param _end_step: The end step of the instruction
|
||||
:param _node_id: The id of the associated node
|
||||
:param _node_pol_type: The pattern of life type
|
||||
:param _service_name: The service name
|
||||
:param _state: The state (node or service)
|
||||
"""
|
||||
self.id = _id
|
||||
self.start_step = _start_step
|
||||
self.end_step = _end_step
|
||||
self.node_id = _node_id
|
||||
self.node_pol_type: "NodePOLType" = _node_pol_type
|
||||
self.service_name: str = _service_name # Not used when not a service instruction
|
||||
# TODO: confirm type of state
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
return self.node_id
|
||||
|
||||
def get_node_pol_type(self) -> "NodePOLType":
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
return self.node_pol_type
|
||||
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
return self.state
|
||||
@@ -1,143 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionRed:
|
||||
"""The Node State Instruction class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_target_node_id: str,
|
||||
_pol_initiator: "NodePOLInitiator",
|
||||
_pol_type: NodePOLType,
|
||||
pol_protocol: str,
|
||||
_pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
|
||||
_pol_source_node_id: str,
|
||||
_pol_source_node_service: str,
|
||||
_pol_source_node_service_state: str,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the Node State Instruction for the red agent.
|
||||
|
||||
:param _id: The node state instruction id
|
||||
:param _start_step: The start step of the instruction
|
||||
:param _end_step: The end step of the instruction
|
||||
:param _target_node_id: The id of the associated node
|
||||
:param -pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
|
||||
:param _pol_type: The pattern of life type
|
||||
:param pol_protocol: The pattern of life protocol/service affected
|
||||
:param _pol_state: The state (node or service)
|
||||
:param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service: The source node service (used for initiator type SERVICE)
|
||||
:param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.target_node_id: str = _target_node_id
|
||||
self.initiator: "NodePOLInitiator" = _pol_initiator
|
||||
self.pol_type: NodePOLType = _pol_type
|
||||
self.service_name: str = pol_protocol # Not used when not a service instruction
|
||||
self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
|
||||
self.source_node_id: str = _pol_source_node_id
|
||||
self.source_node_service: str = _pol_source_node_service
|
||||
self.source_node_service_state = _pol_source_node_service_state
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets the start step.
|
||||
|
||||
Returns:
|
||||
The start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets the end step.
|
||||
|
||||
Returns:
|
||||
The end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_target_node_id(self) -> str:
|
||||
"""
|
||||
Gets the node ID.
|
||||
|
||||
Returns:
|
||||
The node ID
|
||||
"""
|
||||
return self.target_node_id
|
||||
|
||||
def get_initiator(self) -> "NodePOLInitiator":
|
||||
"""
|
||||
Gets the initiator.
|
||||
|
||||
Returns:
|
||||
The initiator
|
||||
"""
|
||||
return self.initiator
|
||||
|
||||
def get_pol_type(self) -> NodePOLType:
|
||||
"""
|
||||
Gets the node pattern of life type (enum).
|
||||
|
||||
Returns:
|
||||
The node pattern of life type (enum)
|
||||
"""
|
||||
return self.pol_type
|
||||
|
||||
def get_service_name(self) -> str:
|
||||
"""
|
||||
Gets the service name.
|
||||
|
||||
Returns:
|
||||
The service name
|
||||
"""
|
||||
return self.service_name
|
||||
|
||||
def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
|
||||
"""
|
||||
Gets the state (node or service).
|
||||
|
||||
Returns:
|
||||
The state (node or service)
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets the source node id (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node id
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_source_node_service(self) -> str:
|
||||
"""
|
||||
Gets the source node service (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service
|
||||
"""
|
||||
return self.source_node_service
|
||||
|
||||
def get_source_node_service_state(self) -> str:
|
||||
"""
|
||||
Gets the source node service state (used for initiator type SERVICE).
|
||||
|
||||
Returns:
|
||||
The source node service state
|
||||
"""
|
||||
return self.source_node_service_state
|
||||
@@ -1,42 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
|
||||
class PassiveNode(Node):
|
||||
"""The Passive Node class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a passive node.
|
||||
|
||||
:param node_id: The node id.
|
||||
:param name: The name of the node.
|
||||
:param node_type: The type of the node.
|
||||
:param priority: The priority of the node.
|
||||
:param hardware_state: The state of the node.
|
||||
:param config_values: Config values.
|
||||
"""
|
||||
# Pass through to Super for now
|
||||
super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
|
||||
|
||||
@property
|
||||
def ip_address(self) -> str:
|
||||
"""
|
||||
Gets the node IP address as an empty string.
|
||||
|
||||
No concept of IP address for passive nodes for now.
|
||||
|
||||
:return: The node IP address.
|
||||
"""
|
||||
return ""
|
||||
@@ -1,190 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceNode(ActiveNode):
|
||||
"""ServiceNode class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a Service Node.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
:param node_type: The node type (enum)
|
||||
:param priority: The node priority (enum)
|
||||
:param hardware_state: The node Hardware State
|
||||
:param ip_address: The node IP address
|
||||
:param software_state: The node Software State
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(
|
||||
node_id,
|
||||
name,
|
||||
node_type,
|
||||
priority,
|
||||
hardware_state,
|
||||
ip_address,
|
||||
software_state,
|
||||
file_system_state,
|
||||
config_values,
|
||||
)
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service) -> None:
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
:param service: The service to add
|
||||
"""
|
||||
self.services[service.name] = service
|
||||
|
||||
def has_service(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
:param protocol_name: The service (protocol)e.
|
||||
:return: True if service (protocol) is on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def service_running(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state != SoftwareState.PATCHING:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state == SoftwareState.OVERWHELMED:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
# Can't set to compromised if you're in a patching state
|
||||
if (
|
||||
software_state == SoftwareState.COMPROMISED
|
||||
and service_value.software_state != SoftwareState.PATCHING
|
||||
) or software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
Done if the software_state is not "compromised".
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
if service_value.software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name: str) -> SoftwareState:
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
:return: The software_state of the service.
|
||||
"""
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
return service_value.software_state
|
||||
|
||||
def update_services_patching_status(self) -> None:
|
||||
"""Updates the patching counter for any service that are patching."""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_value.software_state == SoftwareState.PATCHING:
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self) -> None:
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self) -> None:
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Pattern of Life- Represents the actions of users on the network."""
|
||||
@@ -1,264 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_iers(
|
||||
network: MultiGraph,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
links: Dict[str, Link],
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link pattern of life).
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
nodes: The nodes within the environment
|
||||
links: The links within the environment
|
||||
iers: The IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying IERs")
|
||||
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
protocol = ier_value.get_protocol()
|
||||
port = ier_value.get_port()
|
||||
load = ier_value.get_load()
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
dest_valid = True
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
# TODO: should be using isinstance rather than checking node type attribute. IE. just because it's a switch
|
||||
# doesn't mean it has a software state? It could be a PassiveNode or ActiveNode
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if (
|
||||
source_node.hardware_state == HardwareState.ON
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
source_valid = False
|
||||
elif source_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
# TO DO
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if (
|
||||
source_node.hardware_state == HardwareState.ON
|
||||
and source_node.software_state != SoftwareState.PATCHING
|
||||
):
|
||||
if source_node.has_service(protocol):
|
||||
if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol):
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
dest_valid = False
|
||||
elif dest_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING:
|
||||
if dest_node.has_service(protocol):
|
||||
if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol):
|
||||
dest_valid = True
|
||||
else:
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.ip_address
|
||||
+ ", dest: "
|
||||
+ dest_node.ip_address
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
|
||||
# Check whether both the source and destination are valid, and there's no ACL block
|
||||
if source_valid and dest_valid and not acl_block:
|
||||
# Load up the link(s) with the traffic
|
||||
|
||||
if _VERBOSE:
|
||||
print("Source, Dest and ACL valid")
|
||||
|
||||
# Get the shortest path (i.e. nodes) between source and destination
|
||||
path_node_list = shortest_path(network, source_node, dest_node)
|
||||
path_node_list_length = len(path_node_list)
|
||||
path_valid = True
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
for node in path_node_list:
|
||||
if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
count = 0
|
||||
link_capacity_exceeded = False
|
||||
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
# Now apply the new loads to the links
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
|
||||
def apply_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
node_pol: Dict[str, NodeStateInstructionGreen],
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
node_pol: The node pattern of life to apply
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying Node PoL")
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
node_id = node_instruction.get_node_id()
|
||||
node_pol_type = node_instruction.get_node_pol_type()
|
||||
service_name = node_instruction.get_service_name()
|
||||
state = node_instruction.get_state()
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
node = nodes[node_id]
|
||||
|
||||
if node_pol_type == NodePOLType.OPERATING:
|
||||
# Change hardware state
|
||||
node.hardware_state = state
|
||||
elif node_pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_software_state_if_not_compromised(state)
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
# Don't allow PoL to fix something that is compromised. Only the Blue agent can do this
|
||||
if isinstance(node, ServiceNode):
|
||||
node.set_service_state_if_not_compromised(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
node.set_file_system_state_if_not_compromised(state)
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
@@ -1,147 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
|
||||
|
||||
class IER(object):
|
||||
"""Information Exchange Requirement class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_id: str,
|
||||
_start_step: int,
|
||||
_end_step: int,
|
||||
_load: int,
|
||||
_protocol: str,
|
||||
_port: str,
|
||||
_source_node_id: str,
|
||||
_dest_node_id: str,
|
||||
_mission_criticality: int,
|
||||
_running: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an Information Exchange Request.
|
||||
|
||||
:param _id: The IER id
|
||||
:param _start_step: The step when this IER should start
|
||||
:param _end_step: The step when this IER should end
|
||||
:param _load: The load this IER should put on a link (bps)
|
||||
:param _protocol: The protocol of this IER
|
||||
:param _port: The port this IER runs on
|
||||
:param _source_node_id: The source node ID
|
||||
:param _dest_node_id: The destination node ID
|
||||
:param _mission_criticality: Criticality of this IER to the mission (0 none, 5 mission critical)
|
||||
:param _running: Indicates whether the IER is currently running
|
||||
"""
|
||||
self.id: str = _id
|
||||
self.start_step: int = _start_step
|
||||
self.end_step: int = _end_step
|
||||
self.source_node_id: str = _source_node_id
|
||||
self.dest_node_id: str = _dest_node_id
|
||||
self.load: int = _load
|
||||
self.protocol: str = _protocol
|
||||
self.port: str = _port
|
||||
self.mission_criticality: int = _mission_criticality
|
||||
self.running: bool = _running
|
||||
|
||||
def get_id(self) -> str:
|
||||
"""
|
||||
Gets IER ID.
|
||||
|
||||
Returns:
|
||||
IER ID
|
||||
"""
|
||||
return self.id
|
||||
|
||||
def get_start_step(self) -> int:
|
||||
"""
|
||||
Gets IER start step.
|
||||
|
||||
Returns:
|
||||
IER start step
|
||||
"""
|
||||
return self.start_step
|
||||
|
||||
def get_end_step(self) -> int:
|
||||
"""
|
||||
Gets IER end step.
|
||||
|
||||
Returns:
|
||||
IER end step
|
||||
"""
|
||||
return self.end_step
|
||||
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets IER load.
|
||||
|
||||
Returns:
|
||||
IER load
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def get_protocol(self) -> str:
|
||||
"""
|
||||
Gets IER protocol.
|
||||
|
||||
Returns:
|
||||
IER protocol
|
||||
"""
|
||||
return self.protocol
|
||||
|
||||
def get_port(self) -> str:
|
||||
"""
|
||||
Gets IER port.
|
||||
|
||||
Returns:
|
||||
IER port
|
||||
"""
|
||||
return self.port
|
||||
|
||||
def get_source_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER source node ID.
|
||||
|
||||
Returns:
|
||||
IER source node ID
|
||||
"""
|
||||
return self.source_node_id
|
||||
|
||||
def get_dest_node_id(self) -> str:
|
||||
"""
|
||||
Gets IER destination node ID.
|
||||
|
||||
Returns:
|
||||
IER destination node ID
|
||||
"""
|
||||
return self.dest_node_id
|
||||
|
||||
def get_is_running(self) -> bool:
|
||||
"""
|
||||
Informs whether the IER is currently running.
|
||||
|
||||
Returns:
|
||||
True if running
|
||||
"""
|
||||
return self.running
|
||||
|
||||
def set_is_running(self, _value: bool) -> None:
|
||||
"""
|
||||
Sets the running state of the IER.
|
||||
|
||||
Args:
|
||||
_value: running status
|
||||
"""
|
||||
self.running = _value
|
||||
|
||||
def get_mission_criticality(self) -> int:
|
||||
"""
|
||||
Gets the IER mission criticality (used in the reward function).
|
||||
|
||||
Returns:
|
||||
Mission criticality value (0 lowest to 5 highest)
|
||||
"""
|
||||
return self.mission_criticality
|
||||
@@ -1,353 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
|
||||
network: MultiGraph,
|
||||
nodes: Dict[str, NodeUnion],
|
||||
links: Dict[str, Link],
|
||||
iers: Dict[str, IER],
|
||||
acl: AccessControlList,
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies IERs to the links (link POL) resulting from red agent attack.
|
||||
|
||||
Args:
|
||||
network: The network modelled in the environment
|
||||
nodes: The nodes within the environment
|
||||
links: The links within the environment
|
||||
iers: The red agent IERs to apply to the links
|
||||
acl: The Access Control List
|
||||
step: The step number.
|
||||
"""
|
||||
# Go through each IER and check the conditions for it being applied
|
||||
# If everything is in place, apply the IER protocol load to the relevant links
|
||||
for ier_key, ier_value in iers.items():
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
protocol = ier_value.get_protocol()
|
||||
port = ier_value.get_port()
|
||||
load = ier_value.get_load()
|
||||
source_node_id = ier_value.get_source_node_id()
|
||||
dest_node_id = ier_value.get_dest_node_id()
|
||||
|
||||
# Need to set the running status to false first for all IERs
|
||||
ier_value.set_is_running(False)
|
||||
|
||||
source_valid = True
|
||||
dest_valid = True
|
||||
acl_block = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
|
||||
# Get the source and destination node for this link
|
||||
source_node = nodes[source_node_id]
|
||||
dest_node = nodes[dest_node_id]
|
||||
|
||||
# 1. Check the source node situation
|
||||
if source_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
source_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
source_valid = False
|
||||
elif source_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
# TO DO
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
# TODO: this occurs after ruling out the possibility that the node is a switch or an actuator, but it
|
||||
# could still be a passive/active node, therefore it won't have a hardware_state. The logic here needs
|
||||
# to change according to duck typing.
|
||||
if source_node.hardware_state == HardwareState.ON:
|
||||
if source_node.has_service(protocol):
|
||||
# Red agents IERs can only be valid if the source service is in a compromised state
|
||||
if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED:
|
||||
source_valid = True
|
||||
else:
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
source_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
source_valid = False
|
||||
|
||||
# 2. Check the dest node situation
|
||||
if dest_node.node_type == NodeType.SWITCH:
|
||||
# It's a switch
|
||||
if dest_node.hardware_state == HardwareState.ON:
|
||||
dest_valid = True
|
||||
else:
|
||||
# IER no longer valid
|
||||
dest_valid = False
|
||||
elif dest_node.node_type == NodeType.ACTUATOR:
|
||||
# It's an actuator
|
||||
pass
|
||||
else:
|
||||
# It's not a switch or an actuator (so active node)
|
||||
if dest_node.hardware_state == HardwareState.ON:
|
||||
if dest_node.has_service(protocol):
|
||||
# We don't care what state the destination service is in for an IER
|
||||
dest_valid = True
|
||||
else:
|
||||
# Do nothing - IER is not valid on this node
|
||||
# (This shouldn't happen if the IER has been written correctly)
|
||||
dest_valid = False
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
dest_valid = False
|
||||
|
||||
# 3. Check that the ACL doesn't block it
|
||||
acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
|
||||
if acl_block:
|
||||
if _VERBOSE:
|
||||
print(
|
||||
"ACL block on source: "
|
||||
+ source_node.ip_address
|
||||
+ ", dest: "
|
||||
+ dest_node.ip_address
|
||||
+ ", protocol: "
|
||||
+ protocol
|
||||
+ ", port: "
|
||||
+ port
|
||||
)
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("No ACL block")
|
||||
|
||||
# Check whether both the source and destination are valid, and there's no ACL block
|
||||
if source_valid and dest_valid and not acl_block:
|
||||
# Load up the link(s) with the traffic
|
||||
|
||||
if _VERBOSE:
|
||||
print("Source, Dest and ACL valid")
|
||||
|
||||
# Get the shortest path (i.e. nodes) between source and destination
|
||||
path_node_list = shortest_path(network, source_node, dest_node)
|
||||
path_node_list_length = len(path_node_list)
|
||||
path_valid = True
|
||||
|
||||
# We might have a switch in the path, so check all nodes are operational
|
||||
# We're assuming here that red agents can get past switches that are patching
|
||||
for node in path_node_list:
|
||||
if node.hardware_state != HardwareState.ON:
|
||||
path_valid = False
|
||||
|
||||
if path_valid:
|
||||
if _VERBOSE:
|
||||
print("Applying IER to link(s)")
|
||||
count = 0
|
||||
link_capacity_exceeded = False
|
||||
|
||||
# Check that the link capacity is not exceeded by the new load
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1])
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Check whether the new load exceeds the bandwidth
|
||||
if (link.get_current_load() + load) > link.get_bandwidth():
|
||||
link_capacity_exceeded = True
|
||||
if _VERBOSE:
|
||||
print("Link capacity exceeded")
|
||||
pass
|
||||
count += 1
|
||||
|
||||
# Check whether the link capacity for any links on this path have been exceeded
|
||||
if link_capacity_exceeded == False:
|
||||
# Now apply the new loads to the links
|
||||
count = 0
|
||||
while count < path_node_list_length - 1:
|
||||
# Get the link between the next two nodes
|
||||
edge_dict = network.get_edge_data(
|
||||
path_node_list[count],
|
||||
path_node_list[count + 1],
|
||||
)
|
||||
link_id = edge_dict[0].get("id")
|
||||
link = links[link_id]
|
||||
# Add the load from this IER
|
||||
link.add_protocol_load(protocol, load)
|
||||
count += 1
|
||||
# This IER is now valid, so set it to running
|
||||
ier_value.set_is_running(True)
|
||||
if _VERBOSE:
|
||||
print("Red IER was allowed to run in step " + str(step))
|
||||
else:
|
||||
# One of the nodes is not operational
|
||||
if _VERBOSE:
|
||||
print("Path not valid - one or more nodes not operational")
|
||||
pass
|
||||
|
||||
else:
|
||||
if _VERBOSE:
|
||||
print("Red IER was NOT allowed to run in step " + str(step))
|
||||
print("Source, Dest or ACL were not valid")
|
||||
pass
|
||||
# ------------------------------------
|
||||
else:
|
||||
# Do nothing - IER no longer valid
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def apply_red_agent_node_pol(
|
||||
nodes: Dict[str, NodeUnion],
|
||||
iers: Dict[str, IER],
|
||||
node_pol: Dict[str, NodeStateInstructionRed],
|
||||
step: int,
|
||||
) -> None:
|
||||
"""
|
||||
Applies node pattern of life.
|
||||
|
||||
Args:
|
||||
nodes: The nodes within the environment
|
||||
iers: The red agent IERs
|
||||
node_pol: The red agent node pattern of life to apply
|
||||
step: The step number.
|
||||
"""
|
||||
if _VERBOSE:
|
||||
print("Applying Node Red Agent PoL")
|
||||
|
||||
for key, node_instruction in node_pol.items():
|
||||
start_step = node_instruction.get_start_step()
|
||||
stop_step = node_instruction.get_end_step()
|
||||
target_node_id = node_instruction.get_target_node_id()
|
||||
initiator = node_instruction.get_initiator()
|
||||
pol_type = node_instruction.get_pol_type()
|
||||
service_name = node_instruction.get_service_name()
|
||||
state = node_instruction.get_state()
|
||||
source_node_id = node_instruction.get_source_node_id()
|
||||
source_node_service_name = node_instruction.get_source_node_service()
|
||||
source_node_service_state_value = node_instruction.get_source_node_service_state()
|
||||
|
||||
passed_checks = False
|
||||
|
||||
if step >= start_step and step <= stop_step:
|
||||
# continue --------------------------
|
||||
target_node: NodeUnion = nodes[target_node_id]
|
||||
|
||||
# check if the initiator type is a str, and if so, cast it as
|
||||
# NodePOLInitiator
|
||||
if isinstance(initiator, str):
|
||||
initiator = NodePOLInitiator[initiator]
|
||||
|
||||
# Based the action taken on the initiator type
|
||||
if initiator == NodePOLInitiator.DIRECT:
|
||||
# No conditions required, just apply the change
|
||||
passed_checks = True
|
||||
elif initiator == NodePOLInitiator.IER:
|
||||
# Need to check there is a red IER incoming
|
||||
passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
|
||||
elif initiator == NodePOLInitiator.SERVICE:
|
||||
# Need to check the condition of a service on another node
|
||||
source_node = nodes[source_node_id]
|
||||
if source_node.has_service(source_node_service_name):
|
||||
if (
|
||||
source_node.get_service_state(source_node_service_name)
|
||||
== SoftwareState[source_node_service_state_value]
|
||||
):
|
||||
passed_checks = True
|
||||
else:
|
||||
# Do nothing, no matching state value
|
||||
pass
|
||||
else:
|
||||
# Do nothing, service not on this node
|
||||
pass
|
||||
else:
|
||||
_LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")
|
||||
|
||||
# Only apply the PoL if the checks have passed (based on the initiator type)
|
||||
if passed_checks:
|
||||
# Apply the change
|
||||
if pol_type == NodePOLType.OPERATING:
|
||||
# Change hardware state
|
||||
target_node.hardware_state = state
|
||||
elif pol_type == NodePOLType.OS:
|
||||
# Change OS state
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.software_state = state
|
||||
elif pol_type == NodePOLType.SERVICE:
|
||||
# Change a service state
|
||||
if isinstance(target_node, ServiceNode):
|
||||
target_node.set_service_state(service_name, state)
|
||||
else:
|
||||
# Change the file system status
|
||||
if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode):
|
||||
target_node.set_file_system_state(state)
|
||||
else:
|
||||
_LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
|
||||
else:
|
||||
# PoL is not valid in this time step
|
||||
pass
|
||||
|
||||
|
||||
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
|
||||
"""Checks if the RED IER is incoming.
|
||||
|
||||
:param node: Destination node of the IER
|
||||
:type node: NodeUnion
|
||||
:param iers: Directory of IERs
|
||||
:type iers: Dict[str,IER]
|
||||
:param node_pol_type: Type of Pattern-Of-Life
|
||||
:type node_pol_type: NodePOLType
|
||||
:return: Whether the RED IER is incoming.
|
||||
:rtype: bool
|
||||
"""
|
||||
node_id = node.node_id
|
||||
|
||||
for ier_key, ier_value in iers.items():
|
||||
if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id:
|
||||
if (
|
||||
node_pol_type == NodePOLType.OPERATING
|
||||
or node_pol_type == NodePOLType.OS
|
||||
or node_pol_type == NodePOLType.FILE
|
||||
):
|
||||
# It's looking to change hardware state, file system or SoftwareState, so valid
|
||||
return True
|
||||
elif node_pol_type == NodePOLType.SERVICE:
|
||||
# Check if the service is present on the node and running
|
||||
ier_protocol = ier_value.get_protocol()
|
||||
if isinstance(node, ServiceNode):
|
||||
if node.has_service(ier_protocol):
|
||||
if node.service_running(ier_protocol):
|
||||
# Matching service is present and running, so valid
|
||||
return True
|
||||
else:
|
||||
# Service is present, but not running
|
||||
return False
|
||||
else:
|
||||
# Service is not present
|
||||
return False
|
||||
else:
|
||||
# Not a service node
|
||||
return False
|
||||
else:
|
||||
# Shouldn't get here - instruction type is undefined
|
||||
return False
|
||||
else:
|
||||
# The IER destination is not this node, or the IER is not running
|
||||
return False
|
||||
@@ -1,228 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
|
||||
# from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession class.
|
||||
|
||||
Provides a single learning and evaluation entry point for all training and lay down configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path, legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path,
|
||||
self.session_path,
|
||||
self.legacy_training_config,
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
# elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# # Ray RLlib Agent
|
||||
# self._agent_session = RLlibAgent(
|
||||
# self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
# )
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
raise ValueError
|
||||
|
||||
self.session_path: Path = self._agent_session.session_path
|
||||
self.timestamp_str: str = self._agent_session.timestamp_str
|
||||
self.learning_path: Path = self._agent_session.learning_path
|
||||
self.evaluation_path: Path = self._agent_session.evaluation_path
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(**kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the learn av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the eval av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the learn all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the eval all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
||||
"""Read the session_metadata.json file and return as a dict."""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
return json.load(file)
|
||||
@@ -1,14 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Record data of the system's state and agent's observations and actions."""
|
||||
@@ -1,102 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import AgentIdentifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class Transaction(object):
|
||||
"""Transaction class."""
|
||||
|
||||
def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int) -> None:
|
||||
"""
|
||||
Transaction constructor.
|
||||
|
||||
:param agent_identifier: An identifier for the agent in use
|
||||
:param episode_number: The episode number
|
||||
:param step_number: The step number
|
||||
"""
|
||||
self.timestamp: datetime = datetime.now()
|
||||
"The datetime of the transaction"
|
||||
self.agent_identifier: AgentIdentifier = agent_identifier
|
||||
"The agent identifier"
|
||||
self.episode_number: int = episode_number
|
||||
"The episode number"
|
||||
self.step_number: int = step_number
|
||||
"The step number"
|
||||
self.obs_space: "spaces.Space" = None
|
||||
"The observation space (pre)"
|
||||
self.obs_space_pre: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space before any actions are taken"
|
||||
self.obs_space_post: Optional[Union["np.ndarray", Tuple["np.ndarray"]]] = None
|
||||
"The observation space after any actions are taken"
|
||||
self.reward: Optional[float] = None
|
||||
"The reward value"
|
||||
self.action_space: Optional[int] = None
|
||||
"The action space invoked by the agent"
|
||||
self.obs_space_description: Optional[List[str]] = None
|
||||
"The env observation space description"
|
||||
|
||||
def as_csv_data(self) -> Tuple[List, List]:
|
||||
"""
|
||||
Converts the Transaction to a csv data row and provides a header.
|
||||
|
||||
:return: A tuple consisting of (header, data).
|
||||
"""
|
||||
if isinstance(self.action_space, int):
|
||||
action_length = self.action_space
|
||||
else:
|
||||
action_length = self.action_space.size
|
||||
|
||||
# Create the action space headers array
|
||||
action_header = []
|
||||
for x in range(action_length):
|
||||
action_header.append("AS_" + str(x))
|
||||
|
||||
# Open up a csv file
|
||||
header = ["Timestamp", "Episode", "Step", "Reward"]
|
||||
header = header + action_header + self.obs_space_description
|
||||
|
||||
row = [
|
||||
str(self.timestamp),
|
||||
str(self.episode_number),
|
||||
str(self.step_number),
|
||||
str(self.reward),
|
||||
]
|
||||
row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist()
|
||||
return header, row
|
||||
|
||||
|
||||
def _turn_action_space_to_array(action_space: Union[int, List[int]]) -> List[str]:
|
||||
"""
|
||||
Turns action space into a string array so it can be saved to csv.
|
||||
|
||||
:param action_space: The action space
|
||||
:return: The action space as an array of strings
|
||||
"""
|
||||
if isinstance(action_space, list):
|
||||
return [str(i) for i in action_space]
|
||||
else:
|
||||
return [str(action_space)]
|
||||
|
||||
|
||||
def _turn_obs_space_to_array(obs_space: "np.ndarray", obs_assets: int, obs_features: int) -> List[str]:
|
||||
"""
|
||||
Turns observation space into a string array so it can be saved to csv.
|
||||
|
||||
:param obs_space: The observation space
|
||||
:param obs_assets: The number of assets (i.e. nodes or links) in the observation space
|
||||
:param obs_features: The number of features associated with the asset
|
||||
:return: The observation space as an array of strings
|
||||
"""
|
||||
return_array = []
|
||||
for x in range(obs_assets):
|
||||
for y in range(obs_features):
|
||||
return_array.append(str(obs_space[x][y]))
|
||||
|
||||
return return_array
|
||||
Reference in New Issue
Block a user