#1962: merge dev into branch + fix minor diffs + ensure that imports pull from src
This commit is contained in:
@@ -1 +1 @@
|
||||
2.0.0
|
||||
3.0.0a1
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Access Control List. Models firewall functionality."""
|
||||
@@ -1,198 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
import logging
|
||||
from typing import Dict, Final, List, Union
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessControlList:
    """Access Control List. Models firewall functionality as an ordered list of rules."""

    def __init__(self, implicit_permission: RulePermissionType, max_acl_rules: int) -> None:
        """
        Initialise the ACL.

        :param implicit_permission: The implicit (default) firewall behaviour, ALLOW or DENY.
        :param max_acl_rules: Maximum number of rules, including the implicit catch-all rule.
        :raises ValueError: If ``implicit_permission`` is neither ALLOW nor DENY.
        """
        # Implicit ALLOW or DENY firewall spec
        self.acl_implicit_permission = implicit_permission
        # Implicit catch-all rule that conceptually sits at the end of the ACL
        if self.acl_implicit_permission == RulePermissionType.DENY:
            self.acl_implicit_rule = ACLRule(RulePermissionType.DENY, "ANY", "ANY", "ANY", "ANY")
        elif self.acl_implicit_permission == RulePermissionType.ALLOW:
            self.acl_implicit_rule = ACLRule(RulePermissionType.ALLOW, "ANY", "ANY", "ANY", "ANY")
        else:
            raise ValueError(f"implicit permission must be ALLOW or DENY, got {self.acl_implicit_permission}")

        # Maximum number of ACL Rules in ACL (the implicit rule occupies the final slot)
        self.max_acl_rules: int = max_acl_rules
        # Fixed-size list of user-defined rules; empty positions hold None
        self._acl: List[Union[ACLRule, None]] = [None] * (self.max_acl_rules - 1)

    @property
    def acl(self) -> List[Union[ACLRule, None]]:
        """Public access method for private _acl, with the implicit rule appended last."""
        return self._acl + [self.acl_implicit_rule]

    def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
        """Checks for IP address matches.

        :param _rule: The rule object to check
        :type _rule: ACLRule
        :param _source_ip_address: Source IP address to compare
        :type _source_ip_address: str
        :param _dest_ip_address: Destination IP address to compare
        :type _dest_ip_address: str
        :return: True if there is a match, otherwise False.
        :rtype: bool
        """
        # A rule matches when each address equals the queried address or is the "ANY" wildcard.
        # (Equivalent to enumerating the four exact/ANY combinations.)
        source_match = _rule.get_source_ip() in ("ANY", _source_ip_address)
        dest_match = _rule.get_dest_ip() in ("ANY", _dest_ip_address)
        return source_match and dest_match

    def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
        """
        Checks for rules that block a protocol / port.

        The first matching rule (in positional order, implicit rule last) decides the outcome.

        Args:
            _source_ip_address: the source IP address to check
            _dest_ip_address: the destination IP address to check
            _protocol: the protocol to check
            _port: the port to check

        Returns:
            True if the traffic is blocked, otherwise False.
        """
        for rule in self.acl:
            # Skip empty (None) slots in the rule list
            if isinstance(rule, ACLRule):
                if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
                    if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY") and (
                        str(rule.get_port()) == str(_port) or rule.get_port() == "ANY"
                    ):
                        # There's a matching rule. Get the permission
                        if rule.get_permission() == RulePermissionType.DENY:
                            return True
                        elif rule.get_permission() == RulePermissionType.ALLOW:
                            return False

        # If there has been no rule to allow the IER through, it will return a blocked signal by default
        return True

    def add_rule(
        self,
        _permission: RulePermissionType,
        _source_ip: str,
        _dest_ip: str,
        _protocol: str,
        _port: str,
        _position: str,
    ) -> None:
        """
        Adds a new rule.

        Args:
            _permission: the permission value (e.g. "ALLOW" or "DENY")
            _source_ip: the source IP address
            _dest_ip: the destination IP address
            _protocol: the protocol
            _port: the port
            _position: position to insert ACL rule into ACL list (starting from index 1 and NOT 0)
        """
        try:
            position_index = int(_position)
        except (TypeError, ValueError):
            # int() raises TypeError for None and ValueError for non-numeric strings;
            # the previous code only caught TypeError, so bad strings crashed the caller.
            _LOGGER.info(f"Position {_position} could not be converted to integer.")
            return

        new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
        # Checks position is in correct range (the final slot belongs to the implicit rule)
        if self.max_acl_rules - 1 > position_index > -1:
            _LOGGER.info(f"Position {position_index} is valid.")
            # Check to see Agent will not overwrite current ACL in ACL list
            if self._acl[position_index] is None:
                _LOGGER.info(f"Inserting rule {new_rule} at position {position_index}")
                # Adds rule
                self._acl[position_index] = new_rule
            else:
                # Cannot overwrite an occupied position
                _LOGGER.info(f"Error: inserting rule at non-empty position {position_index}")
        else:
            _LOGGER.info(f"Position {position_index} is an invalid/overwrites implicit firewall rule")

    def remove_rule(
        self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
    ) -> None:
        """
        Removes a rule.

        Args:
            _permission: the permission value (e.g. "ALLOW" or "DENY")
            _source_ip: the source IP address
            _dest_ip: the destination IP address
            _protocol: the protocol
            _port: the port
        """
        # Rules are matched by the hash of their (permission, source, dest, protocol, port) fields.
        delete_rule_hash = hash(ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)))

        for index, rule in enumerate(self._acl):
            if isinstance(rule, ACLRule) and hash(rule) == delete_rule_hash:
                self._acl[index] = None

    def remove_all_rules(self) -> None:
        """Removes all rules (the implicit rule is not stored in _acl and is unaffected)."""
        for i in range(len(self._acl)):
            self._acl[i] = None

    def get_dictionary_hash(
        self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
    ) -> int:
        """
        Produces a hash value for a rule.

        Args:
            _permission: the permission value (e.g. "ALLOW" or "DENY")
            _source_ip: the source IP address
            _dest_ip: the destination IP address
            _protocol: the protocol
            _port: the port

        Returns:
            Hash value based on rule parameters.
        """
        rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port))
        return hash(rule)

    def get_relevant_rules(
        self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str
    ) -> Dict[int, ACLRule]:
        """Get all ACL rules that relate to the given arguments.

        :param _source_ip_address: the source IP address to check
        :param _dest_ip_address: the destination IP address to check
        :param _protocol: the protocol to check
        :param _port: the port to check
        :return: Dictionary of all ACL rules that relate to the given arguments, keyed by rule position
        :rtype: Dict[int, ACLRule]
        """
        relevant_rules = {}
        # Iterate by position so empty (None) slots are skipped safely; the previous
        # implementation called check_address_match on None entries (AttributeError) and
        # looked up self._acl.index(rule), which raised ValueError for the implicit rule.
        for index, rule in enumerate(self.acl):
            if not isinstance(rule, ACLRule):
                continue
            if self.check_address_match(rule, _source_ip_address, _dest_ip_address):
                if (rule.get_protocol() == _protocol or rule.get_protocol() == "ANY" or _protocol == "ANY") and (
                    str(rule.get_port()) == str(_port) or rule.get_port() == "ANY" or str(_port) == "ANY"
                ):
                    # There's a matching rule.
                    relevant_rules[index] = rule

        return relevant_rules
|
||||
@@ -1,87 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A class that implements an access control list rule."""
|
||||
from primaite.common.enums import RulePermissionType
|
||||
|
||||
|
||||
class ACLRule:
    """Access Control List Rule class."""

    def __init__(
        self, _permission: RulePermissionType, _source_ip: str, _dest_ip: str, _protocol: str, _port: str
    ) -> None:
        """
        Create a firewall rule from its five defining fields.

        :param _permission: The permission (ALLOW or DENY)
        :param _source_ip: The source IP address
        :param _dest_ip: The destination IP address
        :param _protocol: The rule protocol
        :param _port: The rule port
        """
        self.permission: RulePermissionType = _permission
        self.source_ip: str = _source_ip
        self.dest_ip: str = _dest_ip
        self.protocol: str = _protocol
        self.port: str = _port

    def __hash__(self) -> int:
        """
        Override the hash function.

        Returns:
            Hash of the (permission, source_ip, dest_ip, protocol, port) tuple.
        """
        key = (
            self.permission,
            self.source_ip,
            self.dest_ip,
            self.protocol,
            self.port,
        )
        return hash(key)

    def get_permission(self) -> str:
        """
        Gets the permission attribute.

        Returns:
            Returns permission attribute
        """
        return self.permission

    def get_source_ip(self) -> str:
        """
        Gets the source IP address attribute.

        Returns:
            Returns source IP address attribute
        """
        return self.source_ip

    def get_dest_ip(self) -> str:
        """
        Gets the destination IP address attribute.

        Returns:
            Returns destination IP address attribute
        """
        return self.dest_ip

    def get_protocol(self) -> str:
        """
        Gets the protocol attribute.

        Returns:
            Returns protocol attribute
        """
        return self.protocol

    def get_port(self) -> str:
        """
        Gets the port attribute.

        Returns:
            Returns port attribute
        """
        return self.port
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Common interface between RL agents from different libraries and PrimAITE."""
|
||||
@@ -1,319 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/2.0.0/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    day = session_timestamp.strftime("%Y-%m-%d")
    stamp = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = PRIMAITE_PATHS.user_sessions_path / day / stamp
    # Create the directory tree up front so callers can write into it immediately.
    session_path.mkdir(exist_ok=True, parents=True)
    return session_path
|
||||
|
||||
|
||||
class AgentSessionABC(ABC):
    """
    An ABC that manages training and/or evaluation of agents in PrimAITE.

    This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
    implemented.
    """

    @abstractmethod
    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = None,
        lay_down_config_path: Optional[Union[str, Path]] = None,
        session_path: Optional[Union[str, Path]] = None,
        legacy_training_config: bool = False,
        legacy_lay_down_config: bool = False,
    ) -> None:
        """
        Initialise an agent session from config files, or load a previous session.

        If training configuration and laydown configuration are provided with a session path,
        the session path will be used.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[path, str]
        :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        :param session_path: directory path of the session to load
        :raises Exception: If `session_path` is supplied but does not exist on disk.
        """
        # initialise variables
        self._env: Primaite  # annotated only; concrete subclasses assign the env in _setup()
        self._agent = None  # the underlying agent; set by subclasses
        self._can_learn: bool = False  # gate checked by learn()
        self._can_evaluate: bool = False  # gate checked by evaluate()
        self.is_eval = False  # True while the session is in evaluation mode
        self.legacy_training_config = legacy_training_config
        self.legacy_lay_down_config = legacy_lay_down_config

        self.session_timestamp: datetime = datetime.now()

        # convert session to path
        if session_path is not None:
            if not isinstance(session_path, Path):
                session_path = Path(session_path)

            # if a session path is provided, load it
            if not session_path.exists():
                raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")

            # load session
            self.load(session_path)
        else:
            # set training config path
            if not isinstance(training_config_path, Path):
                training_config_path = Path(training_config_path)
            self._training_config_path: Union[Path, str] = training_config_path
            self._training_config: TrainingConfig = training_config.load(
                self._training_config_path, legacy_file=legacy_training_config
            )

            # set lay down config path
            if not isinstance(lay_down_config_path, Path):
                lay_down_config_path = Path(lay_down_config_path)
            self._lay_down_config_path: Union[Path, str] = lay_down_config_path
            self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
            self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level

            # set random UUID for session
            self._uuid = str(uuid4())
            # NOTE(review): the bare strings below are inert statements, not attribute
            # docstrings on the preceding assignments; the first appears mislabelled.
            "The session timestamp"
            self.session_path = get_session_path(self.session_timestamp)
            "The Session path"

    @property
    def timestamp_str(self) -> str:
        """The session timestamp as a string."""
        return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")

    @property
    def learning_path(self) -> Path:
        """The learning outputs path (created on first access)."""
        path = self.session_path / "learning"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def evaluation_path(self) -> Path:
        """The evaluation outputs path (created on first access)."""
        path = self.session_path / "evaluation"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def checkpoints_path(self) -> Path:
        """The Session checkpoints path (created on first access)."""
        path = self.learning_path / "checkpoints"
        path.mkdir(exist_ok=True, parents=True)
        return path

    @property
    def uuid(self) -> str:
        """The Agent Session UUID."""
        return self._uuid

    def _write_session_metadata_file(self) -> None:
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_path`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items

        """
        # Episode/step totals start as None and are filled in by
        # _update_session_metadata_file when the session finishes.
        metadata_dict = {
            "uuid": self.uuid,
            "start_datetime": self.session_timestamp.isoformat(),
            "end_datetime": None,
            "learning": {"total_episodes": None, "total_time_steps": None},
            "evaluation": {"total_episodes": None, "total_time_steps": None},
            "env": {
                "training_config": self._training_config.to_dict(json_serializable=True),
                "lay_down_config": self._lay_down_config,
            },
        }
        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Writing Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished writing session metadata file")

    def _update_session_metadata_file(self) -> None:
        """
        Update the ``session_metadata.json`` file.

        Updates the `session_metadata.json`` in the ``session_path`` directory
        with the following key/value pairs:

        - end_datetime: The date & time the session ended in iso format.
        - total_episodes: The total number of training episodes completed.
        - total_time_steps: The total number of training time steps completed.
        """
        # Read the existing metadata, patch it in memory, then rewrite the file.
        with open(self.session_path / "session_metadata.json", "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        # Totals are recorded under "learning" or "evaluation" depending on the mode.
        if not self.is_eval:
            metadata_dict["learning"]["total_episodes"] = self._env.actual_episode_count  # noqa
            metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count  # noqa
        else:
            metadata_dict["evaluation"]["total_episodes"] = self._env.actual_episode_count  # noqa
            metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count  # noqa

        filepath = self.session_path / "session_metadata.json"
        _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
        with open(filepath, "w") as file:
            json.dump(metadata_dict, file)
        _LOGGER.debug("Finished updating session metadata file")

    @abstractmethod
    def _setup(self) -> None:
        # Shared setup tail: subclasses create self._env, then call super()._setup().
        _LOGGER.info(
            "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})"
        )
        _LOGGER.info(f"The output directory for this session is: {self.session_path}")
        self._write_session_metadata_file()
        self._can_learn = True
        self._can_evaluate = False

    @abstractmethod
    def _save_checkpoint(self) -> None:
        # Subclasses persist an intermediate copy of the agent.
        pass

    @abstractmethod
    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        # Shared tail logic: subclasses do the actual training, then presumably
        # call super().learn() to record completion — TODO confirm against subclasses.
        if self._can_learn:
            _LOGGER.info("Finished learning")
            _LOGGER.debug("Writing transactions")
            self._update_session_metadata_file()
            self._can_evaluate = True
            self.is_eval = False

    @abstractmethod
    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        # Shared tail logic for subclasses, mirroring learn() above.
        if self._can_evaluate:
            self._update_session_metadata_file()
            self.is_eval = True
            self._plot_av_reward_per_episode(learning_session=False)
            _LOGGER.info("Finished evaluation")

    @abstractmethod
    def _get_latest_checkpoint(self) -> None:
        # Subclasses restore the most recent saved checkpoint.
        pass

    def load(self, path: Union[str, Path]) -> None:
        """Load an agent from file.

        :param path: directory of a previously-saved session containing session_metadata.json.
        """
        md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)

        # set training config path
        self._training_config_path: Union[Path, str] = training_config_path
        self._training_config: TrainingConfig = training_config.load(self._training_config_path)
        self._lay_down_config_path: Union[Path, str] = laydown_config_path
        self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
        self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level

        # reuse the UUID recorded in the saved session's metadata
        self._uuid = md_dict["uuid"]

        # set the session path
        self.session_path = path
        "The Session path"

    @property
    def _saved_agent_path(self) -> Path:
        # File the trained agent is saved to, named <framework>_<identifier>.zip.
        file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
        return self.learning_path / file_name

    @abstractmethod
    def save(self) -> None:
        """Save the agent."""
        pass

    @abstractmethod
    def export(self) -> None:
        """Export the agent to transportable file format."""
        pass

    def close(self) -> None:
        """Closes the agent."""
        # Flush and close the environment's CSV writers.
        self._env.episode_av_reward_writer.close()  # noqa
        self._env.transaction_writer.close()  # noqa

    def _plot_av_reward_per_episode(self, learning_session: bool = True) -> None:
        # Render the per-episode average reward CSV as a PNG alongside it,
        # in the learning or evaluation output directory as appropriate.
        # self.close()
        title = f"PrimAITE Session {self.timestamp_str} "
        subtitle = str(self._training_config)
        csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
        image_file = f"average_reward_per_episode_{self.timestamp_str}.png"
        if learning_session:
            title += "(Learning)"
            path = self.learning_path / csv_file
            image_path = self.learning_path / image_file
        else:
            title += "(Evaluation)"
            path = self.evaluation_path / csv_file
            image_path = self.evaluation_path / image_file

        fig = plot_av_reward_per_episode(path, title, subtitle)
        fig.write_image(image_path)
        _LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
@@ -1,118 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
    """
    An Agent Session ABC for evaluation deterministic agents.

    This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
    implemented.
    """

    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = "",
        lay_down_config_path: Optional[Union[str, Path]] = "",
        session_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Initialise a hardcoded agent session.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[path, str]
        :param session_path: directory path of a previous session to load.
        """
        super().__init__(training_config_path, lay_down_config_path, session_path)
        # Build the environment immediately; deterministic agents need no training setup.
        self._setup()

    def _setup(self) -> None:
        # Create the PrimAITE gym environment for this session.
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
        )
        super()._setup()
        # Deterministic agents can only be evaluated, never trained.
        self._can_learn = False
        self._can_evaluate = True

    def _save_checkpoint(self) -> None:
        # No-op: deterministic agents have no model state to checkpoint.
        pass

    def _get_latest_checkpoint(self) -> None:
        # No-op: deterministic agents have no checkpoints to restore.
        pass

    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        Hard-coded agents have fixed behaviour, so this only logs a warning.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        _LOGGER.warning("Deterministic agents cannot learn")

    @abstractmethod
    def _calculate_action(self, obs: np.ndarray) -> None:
        # Subclasses implement the deterministic policy: observation -> action.
        pass

    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        Runs `num_eval_episodes` episodes of up to `num_eval_steps` steps each,
        choosing actions with `_calculate_action` and sleeping `time_delay / 1000`
        seconds between steps.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        self._env.set_as_eval()  # noqa
        self.is_eval = True

        time_steps = self._training_config.num_eval_steps
        episodes = self._training_config.num_eval_episodes

        obs = self._env.reset()
        for episode in range(episodes):
            # Reset env and collect initial observation
            for step in range(time_steps):
                # Calculate action
                action = self._calculate_action(obs)

                # Perform the step
                obs, reward, done, info = self._env.step(action)

                # End the episode early if the env reports it is done
                if done:
                    break

                # Introduce a delay between steps
                time.sleep(self._training_config.time_delay / 1000)
            obs = self._env.reset()
        self._env.close()

    @classmethod
    def load(cls, path: Union[str, Path] = None) -> None:
        """Load an agent from file.

        NOTE(review): this overrides the instance method AgentSessionABC.load as a
        classmethod, and only logs a warning.
        """
        _LOGGER.warning("Deterministic agents cannot be loaded")

    def save(self) -> None:
        """Save the agent. No-op for deterministic agents; logs a warning."""
        _LOGGER.warning("Deterministic agents cannot be saved")

    def export(self) -> None:
        """Export the agent to transportable file format. No-op; logs a warning."""
        _LOGGER.warning("Deterministic agents cannot be exported")
|
||||
@@ -1,515 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union

import numpy as np

from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
from primaite.agents.utils import (
    get_new_action,
    get_node_of_ip,
    transform_action_acl_enum,
    transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView, RulePermissionType
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
|
||||
|
||||
|
||||
class HardCodedACLAgent(HardCodedAgentSessionABC):
|
||||
"""An Agent Session class that implements a deterministic ACL agent."""
|
||||
|
||||
def _calculate_action(self, obs: np.ndarray) -> int:
|
||||
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
|
||||
# Basic view action using only the current observation
|
||||
return self._calculate_action_basic_view(obs)
|
||||
else:
|
||||
# full view action using observation space, action
|
||||
# history and reward feedback
|
||||
return self._calculate_action_full_view(obs)
|
||||
|
||||
def get_blocked_green_iers(
|
||||
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[str, IER]:
|
||||
"""Get blocked green IERs.
|
||||
|
||||
:param green_iers: Green IERs to check for being
|
||||
:type green_iers: Dict[str, IER]
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
|
||||
:rtype: Dict[str, IER]
|
||||
"""
|
||||
blocked_green_iers = {}
|
||||
|
||||
for green_ier_id, green_ier in green_iers.items():
|
||||
source_node_id = green_ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = green_ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = green_ier.get_protocol() # e.g. 'TCP'
|
||||
port = green_ier.get_port()
|
||||
|
||||
# Can be blocked by an ACL or by default (no allow rule exists)
|
||||
if acl.is_blocked(source_node_address, dest_node_address, protocol, port):
|
||||
blocked_green_iers[green_ier_id] = green_ier
|
||||
|
||||
return blocked_green_iers
|
||||
|
||||
def get_matching_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""Get list of ACL rules which are relevant to an IER.
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
source_node_id = ier.get_source_node_id()
|
||||
source_node_address = nodes[source_node_id].ip_address
|
||||
dest_node_id = ier.get_dest_node_id()
|
||||
dest_node_address = nodes[dest_node_id].ip_address
|
||||
protocol = ier.get_protocol() # e.g. 'TCP'
|
||||
port = ier.get_port()
|
||||
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
|
||||
return matching_rules
|
||||
|
||||
def get_blocking_acl_rules_for_ier(
|
||||
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
|
||||
) -> Dict[int, ACLRule]:
|
||||
"""
|
||||
Get blocking ACL rules for an IER.
|
||||
|
||||
.. warning::
|
||||
Can return empty dict but IER can still be blocked by default
|
||||
(No ALLOW rule, therefore blocked).
|
||||
|
||||
:param ier: Information Exchange Request to query against the ACL list
|
||||
:type ier: IER
|
||||
:param acl: Firewall rules
|
||||
:type acl: AccessControlList
|
||||
:param nodes: Nodes in the network
|
||||
:type nodes: Dict[str,NodeUnion]
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
|
||||
|
||||
blocked_rules = {}
|
||||
for rule_key, rule_value in matching_rules.items():
|
||||
if rule_value.get_permission() == "DENY":
|
||||
blocked_rules[rule_key] = rule_value
|
||||
|
||||
return blocked_rules
|
||||
|
||||
def get_allow_acl_rules_for_ier(
    self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[int, ACLRule]:
    """Get all allowing ACL rules for an IER.

    :param ier: Information Exchange Request to query against the ACL list
    :type ier: IER
    :param acl: Firewall rules
    :type acl: AccessControlList
    :param nodes: Nodes in the network
    :type nodes: Dict[str,NodeUnion]
    :return: Mapping of rule key to each matching rule whose permission is ALLOW
    :rtype: Dict[int, ACLRule]
    """
    matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
    # Keep only the ALLOW rules among the rules that match this IER.
    return {key: rule for key, rule in matching_rules.items() if rule.get_permission() == "ALLOW"}
|
||||
|
||||
def get_matching_acl_rules(
    self,
    source_node_id: Union[int, str],
    dest_node_id: Union[int, str],
    protocol: Union[int, str],
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, Union[ServiceNode, ActiveNode]],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """Filter ACL rules to only those which are relevant to the specified nodes.

    :param source_node_id: Source node id, or the sentinel string "ANY"
    :type source_node_id: Union[int, str]
    :param dest_node_id: Destination node id, or the sentinel string "ANY"
    :type dest_node_id: Union[int, str]
    :param protocol: 1-based index into ``services_list``, or "ANY"
        (the original annotation of ``str`` was wrong: the code performs
        ``protocol - 1`` arithmetic and callers such as ``get_allow_acl_rules``
        pass an int)
    :type protocol: Union[int, str]
    :param port: Network port
    :type port: str
    :param acl: Access Control list which will be filtered
    :type acl: AccessControlList
    :param nodes: The environment's node directory.
    :type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
    :param services_list: List of services registered for the environment.
    :type services_list: List[str]
    :return: Filtered version of 'acl'
    :rtype: Dict[int, ACLRule]
    """
    # Resolve node ids to IP addresses; the "ANY" sentinel passes straight through.
    if source_node_id != "ANY":
        source_node_address = nodes[str(source_node_id)].ip_address
    else:
        source_node_address = source_node_id

    if dest_node_id != "ANY":
        dest_node_address = nodes[str(dest_node_id)].ip_address
    else:
        dest_node_address = dest_node_id

    if protocol != "ANY":
        # Translate the 1-based protocol index into its service name.
        # -1 because the services list has no "ANY" entry to account for.
        protocol = services_list[protocol - 1]

    matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
    return matching_rules
|
||||
|
||||
def get_allow_acl_rules(
    self,
    source_node_id: int,
    dest_node_id: str,
    protocol: int,
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, NodeUnion],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """List ALLOW rules relating to specified nodes.

    Delegates to :meth:`get_matching_acl_rules` to find every rule relevant
    to the given source/destination pair, then retains only those whose
    permission is ALLOW.

    :param source_node_id: Source node id
    :type source_node_id: int
    :param dest_node_id: Destination node
    :type dest_node_id: str
    :param protocol: Network protocol
    :type protocol: int
    :param port: Port
    :type port: str
    :param acl: Firewall ruleset which is applied to the network
    :type acl: AccessControlList
    :param nodes: The simulation's node store
    :type nodes: Dict[str, NodeUnion]
    :param services_list: Services list
    :type services_list: List[str]
    :return: Filtered ACL Rule directory which includes only ALLOW rules affecting the specified
        source and destination nodes
    :rtype: Dict[int, ACLRule]
    """
    relevant = self.get_matching_acl_rules(
        source_node_id,
        dest_node_id,
        protocol,
        port,
        acl,
        nodes,
        services_list,
    )
    # Retain only the rules whose permission string is "ALLOW".
    return {key: rule for key, rule in relevant.items() if rule.get_permission() == "ALLOW"}
|
||||
|
||||
def get_deny_acl_rules(
    self,
    source_node_id: int,
    dest_node_id: str,
    protocol: int,
    port: str,
    acl: AccessControlList,
    nodes: Dict[str, NodeUnion],
    services_list: List[str],
) -> Dict[int, ACLRule]:
    """List DENY rules relating to specified nodes.

    :param source_node_id: Source node id
    :type source_node_id: int
    :param dest_node_id: Destination node
    :type dest_node_id: str
    :param protocol: Network protocol
    :type protocol: int
    :param port: Port
    :type port: str
    :param acl: Firewall ruleset which is applied to the network
    :type acl: AccessControlList
    :param nodes: The simulation's node store
    :type nodes: Dict[str, NodeUnion]
    :param services_list: Services list
    :type services_list: List[str]
    :return: Filtered ACL Rule directory which includes only DENY rules affecting the specified
        source and destination nodes
    :rtype: Dict[int, ACLRule]
    """
    matching_rules = self.get_matching_acl_rules(
        source_node_id,
        dest_node_id,
        protocol,
        port,
        acl,
        nodes,
        services_list,
    )
    # Retain only DENY rules (the original local was misleadingly named `allowed_rules`).
    deny_rules = {key: rule for key, rule in matching_rules.items() if rule.get_permission() == "DENY"}
    return deny_rules
|
||||
|
||||
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
    """
    Calculate a good acl-based action for the blue agent to take.

    Knowledge of just the observation space is insufficient for a perfect solution, as we need to know:

    - Which ACL rules already exist, - otherwise:
        - The agent would permanently get stuck in a loop of performing the same action over and over.
          (best action is to block something, but it's already blocked and it doesn't know this)
        - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule,
          if it doesn't know what rules exist)
    - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example
      in the default config one of the green IERs is blocked by default, but it has no way of knowing this
      based on the observation space. Additionally, potentially in the future, once a node state
      has been fixed (no longer compromised), it needs a way to know it should reallow traffic.
      A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this.

    There doesn't seem like there's much that can be done if an Operating or OS State is compromised.

    If a service node becomes compromised there's a decision to make - do we block that service?
    Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED.
    Cons: Will block a green IER, decreasing the reward.
    We decide to block the service.

    Potentially a better solution (for the reward) would be to block the incoming traffic from compromised
    nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
    an overwhelmed state, so we don't do this.

    :param obs: current observation from the gym environment
    :type obs: np.ndarray
    :return: Optimal action to take in the environment (chosen from the discrete action space)
    :rtype: int
    """
    # obs = convert_to_old_obs(obs)
    # Decode the flat observation; the leading three entries are discarded and
    # `s` holds one per-node state list for each service.
    r_obs = transform_change_obs_readable(obs)
    _, _, _, *s = r_obs

    if len(r_obs) == 4:  # only 1 service
        s = [*s]

    # 1. Check if node is compromised. If so we want to block its outwards services
    #    a. If it is compromised check if there's an allow rule we should delete.
    #       cons: might delete a multi-rule from any source node (ANY -> x)
    #    b. OPTIONAL (Deny rules not needed): check there isn't already an existing DENY rule, to avoid duplicates
    #    c. OPTIONAL (no allow rule = blocked): add a DENY rule
    found_action = False
    for service_num, service_states in enumerate(s):
        for x, service_state in enumerate(service_states):
            if service_state == "COMPROMISED":
                action_source_id = x + 1  # +1 as 0 is any
                action_destination_id = "ANY"
                action_protocol = service_num + 1  # +1 as 0 is any
                action_port = "ANY"

                allow_rules = self.get_allow_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                deny_rules = self.get_deny_acl_rules(
                    action_source_id,
                    action_destination_id,
                    action_protocol,
                    action_port,
                    self._env.acl,
                    self._env.nodes,
                    self._env.services_list,
                )
                if len(allow_rules) > 0:
                    # Check if there's an allow rule we should delete
                    rule = list(allow_rules.values())[0]
                    action_decision = "DELETE"
                    action_permission = "ALLOW"
                    action_source_ip = rule.get_source_ip()
                    action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes))
                    action_destination_ip = rule.get_dest_ip()
                    action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes))
                    action_protocol_name = rule.get_protocol()
                    action_protocol = (
                        self._env.services_list.index(action_protocol_name) + 1
                    )  # convert name e.g. 'TCP' to index
                    action_port_name = rule.get_port()
                    action_port = (
                        self._env.ports_list.index(action_port_name) + 1
                    )  # convert port name e.g. '80' to index

                    found_action = True
                    break
                elif len(deny_rules) > 0:
                    # TODO OPTIONAL
                    # If there's already a DENY RULE that blocks EVERYTHING from the source ip we don't need
                    # to create another.
                    # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule.
                    continue
                else:
                    # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked
                    # NOTE(review): this branch breaks without setting found_action, so these
                    # CREATE/DENY assignments are never emitted by the `if found_action:` block
                    # below — appears intentional per the TODO, but confirm.
                    action_decision = "CREATE"
                    action_permission = "DENY"
                    break
        if found_action:
            break

    # 2. If NO Node is Compromised, or the node has already been blocked, check the green IERs and
    #    add an Allow rule if the green IER is being blocked.
    #    a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule):
    #       If there's a DENY rule delete it if:
    #       - There isn't already a deny rule
    #       - It doesn't allow a compromised node to become operational.
    #    b. Add an ALLOW rule if:
    #       - There isn't already an allow rule
    #       - It doesn't allow a compromised node to become operational

    if not found_action:
        # Which Green IERS are blocked
        blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes)
        for ier_key, ier in blocked_green_iers.items():
            # Which ALLOW rules are allowing this IER (none)
            allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes)

            # If there are no blocking rules, it may be being blocked by default.
            # If there is already an allow rule.
            node_id_to_check = int(ier.get_source_node_id())
            service_name_to_check = ier.get_protocol()
            service_id_to_check = self._env.services_list.index(service_name_to_check)

            # Service state of the source node in the ier
            service_state = s[service_id_to_check][node_id_to_check - 1]

            if len(allowing_rules) == 0 and service_state != "COMPROMISED":
                action_decision = "CREATE"
                action_permission = "ALLOW"
                action_source_id = int(ier.get_source_node_id())
                action_destination_id = int(ier.get_dest_node_id())
                action_protocol_name = ier.get_protocol()
                action_protocol = (
                    self._env.services_list.index(action_protocol_name) + 1
                )  # convert name e.g. 'TCP' to index
                action_port_name = ier.get_port()
                action_port = (
                    self._env.ports_list.index(action_port_name) + 1
                )  # convert port name e.g. '80' to index

                found_action = True
                break

    if found_action:
        action = [
            action_decision,
            action_permission,
            action_source_id,
            action_destination_id,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    else:
        # If no good/useful action has been found, just perform a nothing action
        action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, self._env.action_dict)
    return action
|
||||
|
||||
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
    """
    Calculate a good acl-based action for the blue agent to take.

    Uses ONLY information from the current observation with NO knowledge
    of previous actions taken and NO reward feedback.

    We rely on randomness to select the precise action, as we want to
    block all traffic originating from a compromised node, without being
    able to tell:
    1. Which ACL rules already exist
    2. Which actions the agent has already tried.

    There is a high probability that the correct rule will not be deleted
    before the state becomes overwhelmed.

    Currently, a deny rule does not overwrite an allow rule. The allow
    rules must be deleted.

    :param obs: current observation from the gym environment
    :type obs: np.ndarray
    :return: Optimal action to take in the environment (chosen from the discrete action space)
    :rtype: int
    """
    action_dict = self._env.action_dict
    # Decode the flat observation; `o` holds per-node operating states and
    # `s` one per-node state list for each service.
    r_obs = transform_change_obs_readable(obs)
    _, o, _, *s = r_obs

    if len(r_obs) == 4:  # only 1 service
        s = [*s]

    number_of_nodes = len([i for i in o if i != "NONE"])  # number of nodes (not links)
    for service_num, service_states in enumerate(s):
        comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"]
        if len(comprimised_states) == 0:
            # No states are COMPROMISED, try the next service
            continue

        # Pick one compromised node at random to act against.
        compromised_node = np.random.choice(comprimised_states) + 1  # +1 as 0 would be any
        action_decision = "DELETE"
        action_permission = "ALLOW"
        action_source_ip = compromised_node
        # Randomly select a destination ID to block
        action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"])
        # np.random.choice on a mixed list yields strings; convert back to int unless "ANY".
        action_destination_ip = (
            int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip
        )
        action_protocol = service_num + 1  # +1 as 0 is any
        # Randomly select a port to block
        # Bad assumption that number of protocols equals number of ports
        # AND no rules exist with an ANY port
        action_port = np.random.choice(list(range(1, len(s) + 1)))

        action = [
            action_decision,
            action_permission,
            action_source_ip,
            action_destination_ip,
            action_protocol,
            action_port,
        ]
        action = transform_action_acl_enum(action)
        action = get_new_action(action, action_dict)
        # We can only perform 1 action on each step
        return action

    # If no good/useful action has been found, just perform a nothing action
    nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
    nothing_action = transform_action_acl_enum(nothing_action)
    nothing_action = get_new_action(nothing_action, action_dict)
    return nothing_action
|
||||
@@ -1,125 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """An Agent Session class that implements a deterministic Node agent."""

    def _calculate_action(self, obs: np.ndarray) -> int:
        """
        Calculate a good node-based action for the blue agent to take.

        Scans the readable observation in priority order (compromised OS,
        compromised service, overwhelmed service, powered-off node) and
        returns the first corrective action found; otherwise a harmless
        no-op action.

        :param obs: current observation from the gym environment
        :type obs: np.ndarray
        :return: Optimal action to take in the environment (chosen from the discrete action space)
        :rtype: int
        """
        action_dict = self._env.action_dict
        r_obs = transform_change_obs_readable(obs)
        # o = per-node operating states, os_states = per-node OS states,
        # s = one per-node state list for each service.
        # (renamed from `os` to avoid shadowing the builtin module name)
        _, o, os_states, *s = r_obs

        if len(r_obs) == 4:  # only 1 service
            s = [*s]

        # Check in order of most important states (order doesn't currently
        # matter, but it probably should)
        # First see if any OS states are compromised
        for x, os_state in enumerate(os_states):
            if os_state == "COMPROMISED":
                action_node_id = x + 1
                action_node_property = "OS"
                property_action = "PATCHING"
                action_service_index = 0  # does nothing, isn't relevant for os
                action = [
                    action_node_id,
                    action_node_property,
                    property_action,
                    action_service_index,
                ]
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # Next, see if any Services are compromised
        # We fix the compromised state before overwhelmed state,
        # If a compromised entry node is fixed before the overwhelmed state is triggered, instruction is ignored
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "COMPROMISED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [
                        action_node_id,
                        action_node_property,
                        property_action,
                        action_service_index,
                    ]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Next, see if any services are overwhelmed
        # perhaps this should be fixed automatically when the compromised PCs issues are also resolved
        # Currently there's no reason that an Overwhelmed state cannot be resolved before resolving
        # the compromised PCs
        for service_num, service in enumerate(s):
            for x, service_state in enumerate(service):
                if service_state == "OVERWHELMED":
                    action_node_id = x + 1
                    action_node_property = "SERVICE"
                    property_action = "PATCHING"
                    action_service_index = service_num

                    action = [
                        action_node_id,
                        action_node_property,
                        property_action,
                        action_service_index,
                    ]
                    action = transform_action_node_enum(action)
                    action = get_new_action(action, action_dict)
                    # We can only perform 1 action on each step
                    return action

        # Finally, turn on any off nodes
        for x, operating_state in enumerate(o):
            # BUG FIX: the original tested the stale `os_state` left over from the OS loop,
            # so this branch never reacted to the current node's operating state.
            if operating_state == "OFF":
                action_node_id = x + 1
                action_node_property = "OPERATING"
                property_action = "ON"  # Why reset it when we can just turn it on
                action_service_index = 0  # does nothing, isn't relevant for operating state
                action = [
                    action_node_id,
                    action_node_property,
                    property_action,
                    action_service_index,
                ]
                # BUG FIX: transform_action_node_enum takes a single argument, matching every
                # other call site in this method (the original erroneously passed action_dict too).
                action = transform_action_node_enum(action)
                action = get_new_action(action, action_dict)
                # We can only perform 1 action on each step
                return action

        # If no good actions, just go with an action that won't do any harm
        action_node_id = 1
        action_node_property = "NONE"
        property_action = "ON"
        action_service_index = 0
        action = [
            action_node_id,
            action_node_property,
            property_action,
            action_service_index,
        ]
        action = transform_action_node_enum(action)
        action = get_new_action(action, action_dict)

        return action
|
||||
@@ -1,287 +0,0 @@
|
||||
# # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
# from __future__ import annotations
|
||||
|
||||
# import json
|
||||
# import shutil
|
||||
# import zipfile
|
||||
# from datetime import datetime
|
||||
# from logging import Logger
|
||||
# from pathlib import Path
|
||||
# from typing import Any, Callable, Dict, Optional, Union
|
||||
# from uuid import uuid4
|
||||
|
||||
# from primaite import getLogger
|
||||
# from primaite.agents.agent_abc import AgentSessionABC
|
||||
# from primaite.common.enums import AgentFramework, AgentIdentifier, SessionType
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
|
||||
# # from ray.rllib.algorithms import Algorithm
|
||||
# # from ray.rllib.algorithms.a2c import A2CConfig
|
||||
# # from ray.rllib.algorithms.ppo import PPOConfig
|
||||
# # from ray.tune.logger import UnifiedLogger
|
||||
# # from ray.tune.registry import register_env
|
||||
|
||||
|
||||
# # from primaite.exceptions import RLlibAgentError
|
||||
|
||||
# _LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
# # TODO: verify type of env_config
|
||||
# def _env_creator(env_config: Dict[str, Any]) -> Primaite:
|
||||
# return Primaite(
|
||||
# training_config_path=env_config["training_config_path"],
|
||||
# lay_down_config_path=env_config["lay_down_config_path"],
|
||||
# session_path=env_config["session_path"],
|
||||
# timestamp_str=env_config["timestamp_str"],
|
||||
# )
|
||||
|
||||
# # # TODO: verify type hint return type
|
||||
# # def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
# # logdir = session_path / "ray_results"
|
||||
# # logdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# # def logger_creator(config: Dict) -> UnifiedLogger:
|
||||
# # return UnifiedLogger(config, logdir, loggers=None)
|
||||
|
||||
# #     return logger_creator
|
||||
|
||||
|
||||
# # class RLlibAgent(AgentSessionABC):
|
||||
# # """An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
# # def __init__(
|
||||
# # self,
|
||||
# # training_config_path: Optional[Union[str, Path]] = "",
|
||||
# # lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
# # session_path: Optional[Union[str, Path]] = None,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Initialise the RLLib Agent training session.
|
||||
|
||||
# # :param training_config_path: YAML file containing configurable items defined in
|
||||
# # `primaite.config.training_config.TrainingConfig`
|
||||
# # :type training_config_path: Union[path, str]
|
||||
# # :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
# # :type lay_down_config_path: Union[path, str]
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_framework (should be "RLLIB")
|
||||
# # :raises ValueError: If the training config contains a bad value for agent_identifies (should be `PPO`
|
||||
# # or `A2C`)
|
||||
# # """
|
||||
# # # TODO: implement RLlib agent loading
|
||||
# # if session_path is not None:
|
||||
# # msg = "RLlib agent loading has not been implemented yet"
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise NotImplementedError(msg)
|
||||
|
||||
# # super().__init__(training_config_path, lay_down_config_path)
|
||||
# # if self._training_config.session_type == SessionType.EVAL:
|
||||
# # msg = "Cannot evaluate an RLlib agent that hasn't been through training yet."
|
||||
# # _LOGGER.critical(msg)
|
||||
# # raise RLlibAgentError(msg)
|
||||
# # if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# # msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config_class: Union[PPOConfig, A2CConfig]
|
||||
# # if self._training_config.agent_identifier == AgentIdentifier.PPO:
|
||||
# # self._agent_config_class = PPOConfig
|
||||
# # elif self._training_config.agent_identifier == AgentIdentifier.A2C:
|
||||
# # self._agent_config_class = A2CConfig
|
||||
# # else:
|
||||
# # msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}"
|
||||
# # _LOGGER.error(msg)
|
||||
# # raise ValueError(msg)
|
||||
# # self._agent_config: Union[PPOConfig, A2CConfig]
|
||||
|
||||
# # self._current_result: dict
|
||||
# # self._setup()
|
||||
# # _LOGGER.debug(
|
||||
# # f"Created {self.__class__.__name__} using: "
|
||||
# # f"agent_framework={self._training_config.agent_framework}, "
|
||||
# # f"agent_identifier="
|
||||
# # f"{self._training_config.agent_identifier}, "
|
||||
# # f"deep_learning_framework="
|
||||
# # f"{self._training_config.deep_learning_framework}"
|
||||
# # )
|
||||
# # self._train_agent = None # Required to capture the learning agent to close after eval
|
||||
|
||||
# # def _update_session_metadata_file(self) -> None:
|
||||
# # """
|
||||
# # Update the ``session_metadata.json`` file.
|
||||
|
||||
# # Updates the `session_metadata.json`` in the ``session_path`` directory
|
||||
# # with the following key/value pairs:
|
||||
|
||||
# # - end_datetime: The date & time the session ended in iso format.
|
||||
# # - total_episodes: The total number of training episodes completed.
|
||||
# # - total_time_steps: The total number of training time steps completed.
|
||||
# # """
|
||||
# # with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
# # metadata_dict = json.load(file)
|
||||
|
||||
# # metadata_dict["end_datetime"] = datetime.now().isoformat()
|
||||
# # if not self.is_eval:
|
||||
# # metadata_dict["learning"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["learning"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
# # else:
|
||||
# # metadata_dict["evaluation"]["total_episodes"] = self._current_result["episodes_total"] # noqa
|
||||
# # metadata_dict["evaluation"]["total_time_steps"] = self._current_result["timesteps_total"] # noqa
|
||||
|
||||
# # filepath = self.session_path / "session_metadata.json"
|
||||
# # _LOGGER.debug(f"Updating Session Metadata file: {filepath}")
|
||||
# # with open(filepath, "w") as file:
|
||||
# # json.dump(metadata_dict, file)
|
||||
# # _LOGGER.debug("Finished updating session metadata file")
|
||||
|
||||
# # def _setup(self) -> None:
|
||||
# # super()._setup()
|
||||
# # register_env("primaite", _env_creator)
|
||||
# # self._agent_config = self._agent_config_class()
|
||||
|
||||
# # self._agent_config.environment(
|
||||
# # env="primaite",
|
||||
# # env_config=dict(
|
||||
# # training_config_path=self._training_config_path,
|
||||
# # lay_down_config_path=self._lay_down_config_path,
|
||||
# # session_path=self.session_path,
|
||||
# # timestamp_str=self.timestamp_str,
|
||||
# # ),
|
||||
# # )
|
||||
# # self._agent_config.seed = self._training_config.seed
|
||||
|
||||
# # self._agent_config.training(train_batch_size=self._training_config.num_train_steps)
|
||||
# # self._agent_config.framework(framework="tf")
|
||||
|
||||
# # self._agent_config.rollouts(
|
||||
# # num_rollout_workers=1,
|
||||
# # num_envs_per_worker=1,
|
||||
# # horizon=self._training_config.num_train_steps,
|
||||
# # )
|
||||
# # self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path))
|
||||
|
||||
# # def _save_checkpoint(self) -> None:
|
||||
# # checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
# # episode_count = self._current_result["episodes_total"]
|
||||
# # save_checkpoint = False
|
||||
# # if checkpoint_n:
|
||||
# # save_checkpoint = episode_count % checkpoint_n == 0
|
||||
# # if episode_count and save_checkpoint:
|
||||
# # self._agent.save(str(self.checkpoints_path))
|
||||
|
||||
# # def learn(
|
||||
# # self,
|
||||
# # **kwargs: Any,
|
||||
# # ) -> None:
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_train_steps
|
||||
# # episodes = self._training_config.num_train_episodes
|
||||
|
||||
# # _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
|
||||
# # for i in range(episodes):
|
||||
# # self._current_result = self._agent.train()
|
||||
# # self._save_checkpoint()
|
||||
# # self.save()
|
||||
# # super().learn()
|
||||
# # # Done this way as the RLlib eval can only be performed if the session hasn't been stopped
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent = self._agent
|
||||
# # else:
|
||||
# # self._agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
|
||||
# # def _unpack_saved_agent_into_eval(self) -> Path:
|
||||
# # """Unpacks the pre-trained and saved RLlib agent so that it can be reloaded by Ray for eval."""
|
||||
# # agent_restore_path = self.evaluation_path / "agent_restore"
|
||||
# # if agent_restore_path.exists():
|
||||
# # shutil.rmtree(agent_restore_path)
|
||||
# # agent_restore_path.mkdir()
|
||||
# # with zipfile.ZipFile(self._saved_agent_path, "r") as zip_file:
|
||||
# # zip_file.extractall(agent_restore_path)
|
||||
# # return agent_restore_path
|
||||
|
||||
# # def _setup_eval(self):
|
||||
# # self._can_learn = False
|
||||
# # self._can_evaluate = True
|
||||
# # self._agent.restore(str(self._unpack_saved_agent_into_eval()))
|
||||
|
||||
# # def evaluate(
|
||||
# # self,
|
||||
# # **kwargs,
|
||||
# # ):
|
||||
# # """
|
||||
# # Evaluate the agent.
|
||||
|
||||
# # :param kwargs: Any agent-specific key-word args to be passed.
|
||||
# # """
|
||||
# # time_steps = self._training_config.num_eval_steps
|
||||
# # episodes = self._training_config.num_eval_episodes
|
||||
|
||||
# # self._setup_eval()
|
||||
|
||||
# # self._env: Primaite = Primaite(
|
||||
# # self._training_config_path, self._lay_down_config_path, self.session_path, self.timestamp_str
|
||||
# # )
|
||||
|
||||
# # self._env.set_as_eval()
|
||||
# # self.is_eval = True
|
||||
# # if self._training_config.deterministic:
|
||||
# # deterministic_str = "deterministic"
|
||||
# # else:
|
||||
# # deterministic_str = "non-deterministic"
|
||||
# # _LOGGER.info(
|
||||
# # f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
|
||||
# # )
|
||||
# # for episode in range(episodes):
|
||||
# # obs = self._env.reset()
|
||||
# # for step in range(time_steps):
|
||||
# # action = self._agent.compute_single_action(observation=obs, explore=False)
|
||||
|
||||
# # obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
# # self._env.reset()
|
||||
# # self._env.close()
|
||||
# # super().evaluate()
|
||||
# # # Now we're safe to close the learning agent and write the mean rewards per episode for it
|
||||
# # if self._training_config.session_type is not SessionType.TRAIN:
|
||||
# # self._train_agent.stop()
|
||||
# # self._plot_av_reward_per_episode(learning_session=True)
|
||||
# # # Perform a clean-up of the unpacked agent
|
||||
# # if (self.evaluation_path / "agent_restore").exists():
|
||||
# # shutil.rmtree((self.evaluation_path / "agent_restore"))
|
||||
|
||||
# # def _get_latest_checkpoint(self) -> None:
|
||||
# # raise NotImplementedError
|
||||
|
||||
# # @classmethod
|
||||
# # def load(cls, path: Union[str, Path]) -> RLlibAgent:
|
||||
# # """Load an agent from file."""
|
||||
# # raise NotImplementedError
|
||||
|
||||
# # def save(self, overwrite_existing: bool = True) -> None:
|
||||
# # """Save the agent."""
|
||||
# # # Make temp dir to save in isolation
|
||||
# # temp_dir = self.learning_path / str(uuid4())
|
||||
# # temp_dir.mkdir()
|
||||
|
||||
# # # Save the agent to the temp dir
|
||||
# # self._agent.save(str(temp_dir))
|
||||
|
||||
# # # Capture the saved Rllib checkpoint inside the temp directory
|
||||
# # for file in temp_dir.iterdir():
|
||||
# # checkpoint_dir = file
|
||||
# # break
|
||||
|
||||
# # # Zip the folder
|
||||
# # shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
|
||||
|
||||
# # # Drop the temp directory
|
||||
# # shutil.rmtree(temp_dir)
|
||||
|
||||
# # def export(self) -> None:
|
||||
# # """Export the agent to transportable file format."""
|
||||
# # raise NotImplementedError
|
||||
@@ -1,206 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SB3Agent(AgentSessionABC):
    """An AgentSession class that implements a Stable Baselines3 agent."""

    def __init__(
        self,
        training_config_path: Optional[Union[str, Path]] = None,
        lay_down_config_path: Optional[Union[str, Path]] = None,
        session_path: Optional[Union[str, Path]] = None,
        legacy_training_config: bool = False,
        legacy_lay_down_config: bool = False,
    ) -> None:
        """
        Initialise the SB3 Agent training session.

        :param training_config_path: YAML file containing configurable items defined in
            `primaite.config.training_config.TrainingConfig`
        :type training_config_path: Union[path, str]
        :param lay_down_config_path: YAML file containing configurable items for generating network laydown.
        :type lay_down_config_path: Union[path, str]
        :param session_path: Directory of a previous session to resume from, if any.
        :type session_path: Union[path, str]
        :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
            otherwise False.
        :raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
        :raises ValueError: If the training config contains an unexpected value for agent_identifier (should be `PPO`
            or `A2C`)
        """
        super().__init__(
            training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
        )
        # This session class only handles Stable Baselines3; reject configs for other frameworks.
        if not self._training_config.agent_framework == AgentFramework.SB3:
            msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
            _LOGGER.error(msg)
            raise ValueError(msg)
        # Select the SB3 algorithm class from the config; instantiation is deferred to _setup().
        self._agent_class: Union[PPO, A2C]
        if self._training_config.agent_identifier == AgentIdentifier.PPO:
            self._agent_class = PPO
        elif self._training_config.agent_identifier == AgentIdentifier.A2C:
            self._agent_class = A2C
        else:
            msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}"
            _LOGGER.error(msg)
            raise ValueError(msg)

        # Tensorboard logs live under the session's learning directory.
        self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
        self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)

        _LOGGER.debug(
            f"Created {self.__class__.__name__} using: "
            f"agent_framework={self._training_config.agent_framework}, "
            f"agent_identifier="
            f"{self._training_config.agent_identifier}"
        )

        # Tracks whether the session is in evaluation mode; _setup() reads this
        # to decide whether to reset episode/step counters when loading.
        self.is_eval = False

        self._setup()

    def _setup(self) -> None:
        """Set up the SB3 Agent: build the environment, then create or load the agent."""
        self._env = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            session_path=self.session_path,
            timestamp_str=self.timestamp_str,
            legacy_training_config=self.legacy_training_config,
            legacy_lay_down_config=self.legacy_lay_down_config,
        )

        # Check if there is a saved-agent zip file that needs to be loaded;
        # its presence distinguishes a resumed session from a fresh one.
        load_file = next(self.session_path.rglob("*.zip"), None)

        if not load_file:
            # Create a brand-new agent on the freshly built environment.
            # NOTE(review): PPOMlp (stable_baselines3.ppo.MlpPolicy) is passed even when
            # the A2C class was selected — presumably both algorithms accept this
            # actor-critic policy class; confirm against SB3 documentation.
            self._agent = self._agent_class(
                PPOMlp,
                self._env,
                verbose=self.sb3_output_verbose_level,
                n_steps=self._training_config.num_train_steps,
                tensorboard_log=str(self._tensorboard_log_path),
                seed=self._training_config.seed,
            )
        else:
            # Resuming: restore environment counters from the session metadata file.
            with open(self.session_path / "session_metadata.json", "r") as file:
                md_dict = json.load(file)

            # Load environment values.
            if self.is_eval:
                # Evaluation always starts counting from 0.
                self._env.episode_count = 0
                self._env.total_step_count = 0
            else:
                # Carry on from previous learning sessions.
                self._env.episode_count = md_dict["learning"]["total_episodes"]
                self._env.total_step_count = md_dict["learning"]["total_time_steps"]

            # Load the saved agent and attach the new environment to it.
            self._agent = self._agent_class.load(load_file, env=self._env)

            # Re-apply session-level agent settings that are not part of the checkpoint.
            self._agent.verbose = self.sb3_output_verbose_level
            self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"

        super()._setup()

    def _save_checkpoint(self) -> None:
        """Save an agent checkpoint every ``checkpoint_every_n_episodes`` episodes.

        A falsy interval disables checkpointing; episode 0 is never saved.
        """
        checkpoint_n = self._training_config.checkpoint_every_n_episodes
        episode_count = self._env.episode_count
        save_checkpoint = False
        if checkpoint_n:
            save_checkpoint = episode_count % checkpoint_n == 0
        if episode_count and save_checkpoint:
            # NOTE(review): filename is hard-coded to "sb3ppo" even when A2C was
            # selected as the algorithm — confirm whether this is intentional.
            checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
            self._agent.save(checkpoint_path)
            _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")

    def _get_latest_checkpoint(self) -> None:
        # Intentionally a no-op for SB3 sessions; resuming is handled by the
        # zip-file discovery in _setup() instead.
        pass

    def learn(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Train the agent.

        Runs ``num_train_episodes`` calls of ``agent.learn``, each for
        ``num_train_steps`` timesteps, checkpointing and recording the average
        reward after every episode.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        time_steps = self._training_config.num_train_steps
        episodes = self._training_config.num_train_episodes
        self.is_eval = False
        _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...")
        # NOTE(review): loop nesting reconstructed from a whitespace-mangled
        # source — confirm which of the post-loop calls sit inside the loop.
        for i in range(episodes):
            self._agent.learn(total_timesteps=time_steps)
            self._save_checkpoint()
            self._env._write_av_reward_per_episode()  # noqa
            self.save()
        self._env.close()
        super().learn()

        # Save agent a final time after the session bookkeeping completes.
        self.save()

        self._plot_av_reward_per_episode(learning_session=True)

    def evaluate(
        self,
        **kwargs: Any,
    ) -> None:
        """
        Evaluate the agent.

        Runs ``num_eval_episodes`` episodes of ``num_eval_steps`` steps each,
        using the trained policy (optionally deterministic) without learning.

        :param kwargs: Any agent-specific key-word args to be passed.
        """
        time_steps = self._training_config.num_eval_steps
        episodes = self._training_config.num_eval_episodes
        self._env.set_as_eval()
        self.is_eval = True
        if self._training_config.deterministic:
            deterministic_str = "deterministic"
        else:
            deterministic_str = "non-deterministic"
        _LOGGER.info(
            f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..."
        )
        for episode in range(episodes):
            obs = self._env.reset()

            for step in range(time_steps):
                action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
                # The env expects a scalar action; collapse a (size-1) ndarray
                # prediction to a numpy integer.
                if isinstance(action, np.ndarray):
                    action = np.int64(action)
                obs, rewards, done, info = self._env.step(action)
            self._env._write_av_reward_per_episode()  # noqa
        self._env.close()
        super().evaluate()

    def save(self) -> None:
        """Save the agent to the session's saved-agent path."""
        self._agent.save(self._saved_agent_path)

    def export(self) -> None:
        """Export the agent to transportable file format."""
        raise NotImplementedError
|
||||
@@ -1,59 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
|
||||
class RandomAgent(HardCodedAgentSessionABC):
    """
    A Random Agent.

    Get a completely random action from the action space.
    """

    def _calculate_action(self, obs: np.ndarray) -> int:
        # The observation is ignored entirely; sample uniformly from the
        # environment's action space.
        action_space = self._env.action_space
        return action_space.sample()
|
||||
|
||||
|
||||
class DummyAgent(HardCodedAgentSessionABC):
    """
    A Dummy Agent.

    All action spaces setup so dummy action is always 0 regardless of action type used.
    """

    def _calculate_action(self, obs: np.ndarray) -> int:
        # The dummy action is hard-wired: 0 is valid in every action-space type.
        dummy_action = 0
        return dummy_action
|
||||
|
||||
|
||||
class DoNothingACLAgent(HardCodedAgentSessionABC):
    """
    A do nothing ACL agent.

    A valid ACL action that has no effect; does nothing.
    """

    def _calculate_action(self, obs: np.ndarray) -> int:
        # Build the readable no-op ACL action, encode it to enum form, then
        # look up its key in the environment's action dictionary.
        readable_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
        encoded_action = transform_action_acl_enum(readable_action)
        return get_new_action(encoded_action, self._env.action_dict)
|
||||
|
||||
|
||||
class DoNothingNodeAgent(HardCodedAgentSessionABC):
    """
    A do nothing Node agent.

    A valid Node action that has no effect; does nothing.
    """

    def _calculate_action(self, obs: np.ndarray) -> int:
        # Build the readable no-op node action, encode it, then translate it
        # into an action key. The result should currently always be 0.
        readable_action = [1, "NONE", "ON", 0]
        encoded_action = transform_action_node_enum(readable_action)
        return get_new_action(encoded_action, self._env.action_dict)
|
||||
@@ -1,450 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import (
|
||||
HardwareState,
|
||||
LinkStatus,
|
||||
NodeHardwareAction,
|
||||
NodePOLType,
|
||||
NodeSoftwareAction,
|
||||
SoftwareState,
|
||||
)
|
||||
|
||||
|
||||
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
    """Convert a node action from enumerated format to readable format.

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]

    :param action: Agent action, formatted as a list of ints, for more information check out
        `primaite.environment.primaite_env.Primaite`
    :type action: List[int]
    :return: The same action list, but with the encodings translated back into meaningful labels
    :rtype: List[Union[int,str]]
    """
    action_node_property = NodePOLType(action[1]).name

    if action_node_property == "OPERATING":
        property_action = NodeHardwareAction(action[2]).name
    elif action_node_property in ("OS", "SERVICE") and action[2] <= 1:
        # Only values 0/1 map onto NodeSoftwareAction members; anything higher
        # has no software-action meaning and falls through to "NONE".
        property_action = NodeSoftwareAction(action[2]).name
    else:
        property_action = "NONE"

    # Use typing.List here for consistency with the rest of this module,
    # which imports List/Union from typing rather than using builtin generics.
    new_action: List[Union[int, str]] = [action[0], action_node_property, property_action, action[3]]
    return new_action
|
||||
|
||||
|
||||
def transform_action_acl_readable(action: List[int]) -> List[Union[str, int]]:
    """
    Transform an ACL action to a more readable format.

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]

    :param action: Agent action, formatted as a list of ints, for more information check out
        `primaite.environment.primaite_env.Primaite`
    :type action: List[int]
    :return: The same action list, but with the encodings translated back into meaningful labels
    :rtype: List[Union[int,str]]
    """
    decision_names = {0: "NONE", 1: "CREATE", 2: "DELETE"}
    permission_names = {0: "DENY", 1: "ALLOW"}

    readable: List[Union[str, int]] = [
        decision_names[action[0]],
        permission_names[action[1]],
    ]
    # For IPs, ports and protocols a value of 0 encodes "ANY"; any other value
    # is an index and passes through unchanged.
    readable.extend("ANY" if value == 0 else value for value in action[2:6])
    return readable
|
||||
|
||||
|
||||
def is_valid_node_action(action: List[int]) -> bool:
    """
    Is the node action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
    - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
    - Node already being in that state (turning an ON node ON)

    :param action: Agent action, formatted as a list of ints, for more information check out
        `primaite.environment.primaite_env.Primaite`
    :type action: List[int]
    :return: Whether the action is valid
    :rtype: bool
    """
    readable = transform_action_node_readable(action)
    node_property = readable[1]
    node_action = readable[2]

    # A NONE property or NONE verb never has an effect.
    if node_property == "NONE" or node_action == "NONE":
        return False

    if node_property == "OPERATING":
        # Hardware (operating) state cannot be patched; everything else is fine.
        return node_action != "PATCHING"

    # Software states (OS / service) may only do nothing or patch, and NONE
    # was already ruled out above, so only PATCHING remains valid.
    return node_action == "PATCHING"
|
||||
|
||||
|
||||
def is_valid_acl_action(action: List[int]) -> bool:
    """
    Is the ACL action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
    - Trying to create identical rules
    - Trying to create a rule which is a subset of another rule (caused by "ANY")

    :param action: Agent action, formatted as a list of ints, for more information check out
        `primaite.environment.primaite_env.Primaite`
    :type action: List[int]
    :return: Whether the action is valid
    :rtype: bool
    """
    readable = transform_action_acl_readable(action)
    decision = readable[0]
    permission = readable[1]
    source_id = readable[2]
    destination_id = readable[3]

    # A NONE decision has no effect.
    if decision == "NONE":
        return False
    # A rule from a node to itself has no effect (wildcards excluded).
    if source_id == destination_id and source_id != "ANY" and destination_id != "ANY":
        return False
    # DENY rules are unnecessary: traffic with no ALLOW rule is blocked/denied
    # by default, and ALLOW overrides an existing DENY, so creating and
    # deleting ALLOW rules covers every case.
    if permission == "DENY":
        return False

    return True
|
||||
|
||||
|
||||
def is_valid_acl_action_extra(action: List[int]) -> bool:
    """
    Harsher version of valid acl actions, does not allow action.

    On top of the checks in `is_valid_acl_action`, rejects rules whose
    protocol or port is the "ANY" wildcard.

    :param action: Agent action, formatted as a list of ints, for more information check out
        `primaite.environment.primaite_env.Primaite`
    :type action: List[int]
    :return: Whether the action is valid
    :rtype: bool
    """
    # Idiomatic truthiness test instead of the original `... is False` comparison.
    if not is_valid_acl_action(action):
        return False

    action_r = transform_action_acl_readable(action)
    action_protocol = action_r[4]
    action_port = action_r[5]

    # Don't allow protocols or ports to be ANY.
    # In the future we might want to do the opposite, and only have the ANY
    # option for ports and service.
    return action_protocol != "ANY" and action_port != "ANY"
|
||||
|
||||
|
||||
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
    """Transform list of transactions to readable list of each observation property.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]

    :param obs: Raw observation from the environment.
    :type obs: np.ndarray
    :return: The same observation, but the encoded integer values are replaced with readable names.
    :rtype: List[List[Union[str, int]]]
    """
    ids = list(obs[:, 0])
    operating_states = [HardwareState(i).name for i in obs[:, 1]]
    os_states = [SoftwareState(i).name for i in obs[:, 2]]
    new_obs = [ids, operating_states, os_states]

    # Service columns start at index 3 (0=ID, 1=hardware, 2=OS) — see the
    # docstring example above and the layout built by `convert_to_old_obs`.
    # The previous `range(4, ...)` skipped the first service column.
    for service in range(3, obs.shape[1]):
        # Links bit/s don't have a service state, so values above the
        # SoftwareState range are passed through unchanged.
        service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]]
        new_obs.append(service_states)

    return new_obs
|
||||
|
||||
|
||||
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
    """Transform observation to readable format.

    example
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]

    :param obs: Raw observation from the environment.
    :type obs: np.ndarray
    :return: The same observation, but the encoded integer values are replaced with readable names.
    :rtype: List[List[Union[str, int]]]
    """
    # Decode column-by-column, then transpose back to row-per-node/link form,
    # converting each resulting tuple into a list.
    columns = transform_change_obs_readable(obs)
    return [list(row) for row in zip(*columns)]
|
||||
|
||||
|
||||
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
    """Convert original gym Box observation space to new multiDiscrete observation space.

    :param obs: observation in the 'old' (NodeLinkTable) format
    :type obs: np.ndarray
    :param num_nodes: number of nodes in the network, defaults to 10
    :type num_nodes: int, optional
    :return: reformatted observation
    :rtype: np.ndarray
    """
    # Keep only the node rows, drop the leading ID column, and flatten
    # row-major into the MultiDiscrete layout.
    node_states = obs[:num_nodes, 1:]
    return node_states.flatten()
|
||||
|
||||
|
||||
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
    """Convert to old observation.

    Links filled with 0's as no information is included in new observation space.

    example:
        obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])

        new_obs = array([[ 1,  1,  1,  1],
                         [ 2,  1,  1,  1],
                         [ 3,  1,  1,  1],
                         ...
                         [20,  0,  0,  0]])

    :param obs: observation in the 'new' (MultiDiscrete) format
    :type obs: np.ndarray
    :param num_nodes: number of nodes in the network, defaults to 10
    :type num_nodes: int, optional
    :param num_links: number of links in the network, defaults to 10
    :type num_links: int, optional
    :param num_services: number of services on the network, defaults to 1
    :type num_services: int, optional
    :return: 2-d BOX observation space, in the same format as NodeLinkTable
    :rtype: np.ndarray
    """
    # Everything before the trailing link values is node state; each node row
    # carries hardware + OS + one column per service.
    node_part = obs[:-num_links].reshape(num_nodes, num_services + 2)

    # Rebuild the full (nodes + links) table with an extra ID column, zero-filled.
    table = np.zeros(
        (num_nodes + num_links, node_part.shape[1] + 1),
        dtype=np.int64,
    )
    table[:, 0] = np.arange(1, num_nodes + num_links + 1)  # restore the ID column
    table[:num_nodes, 1:] = node_part  # put node values back in

    # Restore link traffic. Links are stored in the last protocol/service
    # column of the link rows, but they are not specific to that service.
    table[num_nodes:, -1] = obs[-num_links:]

    return table
|
||||
|
||||
|
||||
def describe_obs_change(
    obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
) -> str:
    """Build a string describing the difference between two observations.

    example:
        obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
        obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
        output = 'ID 1: SERVICE 2 set to GOOD'

    :param obs1: First observation
    :type obs1: np.ndarray
    :param obs2: Second observation
    :type obs2: np.ndarray
    :param num_nodes: How many nodes are in the network laydown, defaults to 10
    :type num_nodes: int, optional
    :param num_links: How many links are in the network laydown, defaults to 10
    :type num_links: int, optional
    :param num_services: How many services are configured for this scenario, defaults to 1
    :type num_services: int, optional
    :return: A multi-line string with a human-readable description of the difference.
    :rtype: str
    """
    # Work in the 2-d NodeLinkTable layout so rows map to nodes/links.
    obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
    obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)

    descriptions = []
    for row_idx, delta in enumerate(obs1 - obs2):
        if not delta.any():
            continue
        # Keep the new value wherever the row changed, -1 everywhere else.
        relevant = np.where(delta != 0, obs2[row_idx], -1)
        relevant[0] = obs2[row_idx, 0]  # the ID column is always relevant
        row_is_link = relevant[0] > num_nodes
        descriptions.append(_describe_obs_change_helper(relevant, row_is_link))

    if descriptions:
        return "\n " + "\n ".join(descriptions)
    return ""
|
||||
|
||||
|
||||
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
    """
    Helper function to describe what has changed.

    example:
        [ 1 -1 -1 -1  1] -> "ID 1: Service 1 changed to GOOD"

    Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'

    :param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to
        one row of the observation table, and have `-1` at locations where the observation hasn't changed, and the
        new status where it has changed.
    :type obs_change: List[int]
    :param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
    :type is_link: bool
    :return: A human-readable description of the difference between the two observation rows.
    :rtype: str
    """
    # Indices (excluding the ID at position 0) where a new value was recorded.
    changed_indices = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]

    # Label each changed column; indices >= 3 are numbered service columns.
    pol_labels = [
        NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in changed_indices
    ]

    def _state_name(i: int) -> str:
        # Decode with the enum that applies: link traffic for link rows,
        # hardware state for column 1, software state otherwise.
        if is_link:
            return LinkStatus(obs_change[i]).name
        if i == 1:
            return HardwareState(obs_change[i]).name
        return SoftwareState(obs_change[i]).name

    states = [_state_name(i) for i in changed_indices]

    if is_link:
        return f"ID {obs_change[0]}: Link traffic changed to {states[0]}."

    desc = f"ID {obs_change[0]}:"
    for label, state in zip(pol_labels, states):
        desc = desc + " " + label + " changed to " + state + "."
    return desc
|
||||
|
||||
|
||||
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
    """Convert a node action from readable string format, to enumerated format.

    example:
        [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]

    :param action: Action in 'readable' format
    :type action: List[Union[str,int]]
    :return: Action with verbs encoded as ints
    :rtype: List[int]
    """
    node_id = action[0]
    pol_type_value = NodePOLType[action[1]].value

    # The verb enum depends on which property is being acted on: hardware
    # actions for the operating state, software actions for OS/service, and 0
    # (no-op) for anything else.
    if action[1] == "OPERATING":
        verb_value = NodeHardwareAction[action[2]].value
    elif action[1] in ("OS", "SERVICE"):
        verb_value = NodeSoftwareAction[action[2]].value
    else:
        verb_value = 0

    return [node_id, pol_type_value, verb_value, action[3]]
|
||||
|
||||
|
||||
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
    """
    Convert acl action from readable str format, to enumerated format.

    :param action: ACL-based action expressed as a list of human-readable ints and strings
    :type action: List[Union[int,str]]
    :return: The same action but encoded to contain only integers.
    :rtype: np.ndarray
    """
    decision_codes = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    permission_codes = {"DENY": 0, "ALLOW": 1}

    encoded = [decision_codes[action[0]], permission_codes[action[1]]]
    # For IPs, ports and protocols "ANY" encodes to 0; any other value is
    # already an integer index and passes through unchanged.
    encoded.extend(0 if field == "ANY" else field for field in action[2:6])

    return np.array(encoded)
|
||||
|
||||
|
||||
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
    """Get the node ID of an IP address.

    node_dict: dictionary of nodes where key is ID, and value is the node (can be obtained from env.nodes)

    :param ip: The IP address of the node whose ID is required
    :type ip: str
    :param node_dict: The environment's node registry dictionary
    :type node_dict: Dict[str,NodeUnion]
    :return: The key from the registry dict that corresponds to the node with the IP address provided by `ip`
    :rtype: str
    """
    for key, node in node_dict.items():
        if node.ip_address == ip:
            return key
    # NOTE(review): implicitly returns None when no node has this IP, despite
    # the declared `-> str` return type — confirm callers handle the None case.
|
||||
|
||||
|
||||
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
    """
    Get new action (e.g. 32) from old action e.g. [1,1,1,0].

    Old_action can be either node or acl action type.

    :param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
    :type old_action: np.ndarray
    :param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
    :type action_dict: Dict[int,List]
    :return: Action key corresponding to the input `old_action`
    :rtype: int
    """
    target = list(old_action)
    for action_key, encoded in action_dict.items():
        if list(encoded) == target:
            return action_key
    # Only valid actions appear in the dict; anything missing is an invalid
    # action, and invalid actions map to the do-nothing action 0.
    return 0
|
||||
@@ -10,7 +10,6 @@ import yaml
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.data_viz import PlotlyTemplate
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -30,7 +29,7 @@ def reset_notebooks(overwrite: bool = True) -> None:
|
||||
|
||||
:param overwrite: If True, will overwrite existing demo notebooks.
|
||||
"""
|
||||
from primaite.setup import reset_demo_notebooks
|
||||
from src.primaite.setup import reset_demo_notebooks
|
||||
|
||||
reset_demo_notebooks.run(overwrite)
|
||||
|
||||
@@ -81,14 +80,6 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None) ->
|
||||
print(f"PrimAITE Log Level: {level}")
|
||||
|
||||
|
||||
@app.command()
|
||||
def notebooks() -> None:
|
||||
"""Start Jupyter Lab in the users PrimAITE notebooks directory."""
|
||||
from primaite.notebooks import start_jupyter_session
|
||||
|
||||
start_jupyter_session()
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
"""Get the installed PrimAITE version number."""
|
||||
@@ -97,14 +88,6 @@ def version() -> None:
|
||||
print(primaite.__version__)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clean_up() -> None:
|
||||
"""Cleans up left over files from previous version installations."""
|
||||
from primaite.setup import old_installation_clean_up
|
||||
|
||||
old_installation_clean_up.run()
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
"""
|
||||
@@ -112,8 +95,10 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from arcd_gate.cli import setup as gate_setup
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs
|
||||
from src.primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -130,84 +115,32 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Performing a clean-up of previous PrimAITE installations...")
|
||||
old_installation_clean_up.run()
|
||||
_LOGGER.info("Setting up ARCD GATE...")
|
||||
gate_setup()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
tc: Optional[str] = None,
|
||||
ldc: Optional[str] = None,
|
||||
load: Optional[str] = None,
|
||||
legacy_tc: bool = False,
|
||||
legacy_ldc: bool = False,
|
||||
config: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
tc: The training config filepath. Optional. If no value is passed then
|
||||
example default training config is used from:
|
||||
~/primaite/2.0.0/config/example_config/training/training_config_main.yaml.
|
||||
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/2.0.0/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
|
||||
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
|
||||
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
from threading import Thread
|
||||
|
||||
if load is not None:
|
||||
# run a loaded session
|
||||
run(session_path=load)
|
||||
from src.primaite.config.load import example_config_path
|
||||
from src.primaite.main import run
|
||||
from src.primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
else:
|
||||
# start a new session using tc and ldc
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(
|
||||
training_config_path=tc,
|
||||
lay_down_config_path=ldc,
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None) -> None:
|
||||
"""
|
||||
View or set the plotly template for Session plots.
|
||||
|
||||
To View, simply call: primaite plotly-template
|
||||
|
||||
To set, call: primaite plotly-template <desired template>
|
||||
|
||||
For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
|
||||
if template:
|
||||
primaite_config["session"]["outputs"]["plots"]["template"] = template.value
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print(f"PrimAITE plotly template: {template.value}")
|
||||
else:
|
||||
template = primaite_config["session"]["outputs"]["plots"]["template"]
|
||||
print(f"PrimAITE plotly template: {template}")
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Objects which are shared between many PrimAITE modules."""
|
||||
@@ -1,8 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.passive_node import PassiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
NodeUnion = Union[ActiveNode, PassiveNode, ServiceNode]
|
||||
"""A Union of ActiveNode, PassiveNode, and ServiceNode."""
|
||||
@@ -1,208 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""Node type enumeration."""
|
||||
|
||||
CCTV = 1
|
||||
SWITCH = 2
|
||||
COMPUTER = 3
|
||||
LINK = 4
|
||||
MONITOR = 5
|
||||
PRINTER = 6
|
||||
LOP = 7
|
||||
RTU = 8
|
||||
ACTUATOR = 9
|
||||
SERVER = 10
|
||||
|
||||
|
||||
class Priority(Enum):
|
||||
"""Node priority enumeration."""
|
||||
|
||||
P1 = 1
|
||||
P2 = 2
|
||||
P3 = 3
|
||||
P4 = 4
|
||||
P5 = 5
|
||||
|
||||
|
||||
class HardwareState(Enum):
|
||||
"""Node hardware state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESETTING = 3
|
||||
SHUTTING_DOWN = 4
|
||||
BOOTING = 5
|
||||
|
||||
|
||||
class SoftwareState(Enum):
|
||||
"""Software or Service state enumeration."""
|
||||
|
||||
NONE = 0
|
||||
GOOD = 1
|
||||
PATCHING = 2
|
||||
COMPROMISED = 3
|
||||
OVERWHELMED = 4
|
||||
|
||||
|
||||
class NodePOLType(Enum):
|
||||
"""Node Pattern of Life type enumeration."""
|
||||
|
||||
NONE = 0
|
||||
OPERATING = 1
|
||||
OS = 2
|
||||
SERVICE = 3
|
||||
FILE = 4
|
||||
|
||||
|
||||
class NodePOLInitiator(Enum):
|
||||
"""Node Pattern of Life initiator enumeration."""
|
||||
|
||||
DIRECT = 1
|
||||
IER = 2
|
||||
SERVICE = 3
|
||||
|
||||
|
||||
class Protocol(Enum):
|
||||
"""Service protocol enumeration."""
|
||||
|
||||
LDAP = 0
|
||||
FTP = 1
|
||||
HTTPS = 2
|
||||
SMTP = 3
|
||||
RTP = 4
|
||||
IPP = 5
|
||||
TCP = 6
|
||||
NONE = 7
|
||||
|
||||
|
||||
class SessionType(Enum):
|
||||
"""The type of PrimAITE Session to be run."""
|
||||
|
||||
TRAIN = 1
|
||||
"Train an agent"
|
||||
EVAL = 2
|
||||
"Evaluate an agent"
|
||||
TRAIN_EVAL = 3
|
||||
"Train then evaluate an agent"
|
||||
|
||||
|
||||
class AgentFramework(Enum):
|
||||
"""The agent algorithm framework/package."""
|
||||
|
||||
CUSTOM = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
# RLLIB = 2
|
||||
# "Ray RLlib"
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework."""
|
||||
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
"Tensorflow 2.x"
|
||||
TORCH = "torch"
|
||||
"PyTorch"
|
||||
|
||||
|
||||
class AgentIdentifier(Enum):
|
||||
"""The Red Agent algo/class."""
|
||||
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"The Hardcoded agents"
|
||||
DO_NOTHING = 4
|
||||
"The DoNothing agents"
|
||||
RANDOM = 5
|
||||
"The RandomAgent"
|
||||
DUMMY = 6
|
||||
"The DummyAgent"
|
||||
|
||||
|
||||
class HardCodedAgentView(Enum):
|
||||
"""The view the deterministic hard-coded agent has of the environment."""
|
||||
|
||||
BASIC = 1
|
||||
"The current observation space only"
|
||||
FULL = 2
|
||||
"Full environment view with actions taken and reward feedback"
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Action type enumeration."""
|
||||
|
||||
NODE = 0
|
||||
ACL = 1
|
||||
ANY = 2
|
||||
|
||||
|
||||
# TODO: this is not used anymore, write a ticket to delete it.
|
||||
class ObservationType(Enum):
|
||||
"""Observation type enumeration."""
|
||||
|
||||
BOX = 0
|
||||
MULTIDISCRETE = 1
|
||||
|
||||
|
||||
class FileSystemState(Enum):
|
||||
"""File System State."""
|
||||
|
||||
GOOD = 1
|
||||
CORRUPT = 2
|
||||
DESTROYED = 3
|
||||
REPAIRING = 4
|
||||
RESTORING = 5
|
||||
|
||||
|
||||
class NodeHardwareAction(Enum):
|
||||
"""Node hardware action."""
|
||||
|
||||
NONE = 0
|
||||
ON = 1
|
||||
OFF = 2
|
||||
RESET = 3
|
||||
|
||||
|
||||
class NodeSoftwareAction(Enum):
|
||||
"""Node software action."""
|
||||
|
||||
NONE = 0
|
||||
PATCHING = 1
|
||||
|
||||
|
||||
class LinkStatus(Enum):
|
||||
"""Link traffic status."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
OVERLOAD = 4
|
||||
|
||||
|
||||
class SB3OutputVerboseLevel(IntEnum):
|
||||
"""The Stable Baselines3 learn/eval output verbosity level."""
|
||||
|
||||
NONE = 0
|
||||
INFO = 1
|
||||
DEBUG = 2
|
||||
|
||||
|
||||
class RulePermissionType(Enum):
|
||||
"""Any firewall rule type."""
|
||||
|
||||
NONE = 0
|
||||
DENY = 1
|
||||
ALLOW = 2
|
||||
@@ -1,47 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
class Protocol(object):
|
||||
"""Protocol class."""
|
||||
|
||||
def __init__(self, _name: str) -> None:
|
||||
"""
|
||||
Initialise a protocol.
|
||||
|
||||
:param _name: The name of the protocol
|
||||
:type _name: str
|
||||
"""
|
||||
self.name: str = _name
|
||||
self.load: int = 0 # bps
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Gets the protocol name.
|
||||
|
||||
Returns:
|
||||
The protocol name
|
||||
"""
|
||||
return self.name
|
||||
|
||||
def get_load(self) -> int:
|
||||
"""
|
||||
Gets the protocol load.
|
||||
|
||||
Returns:
|
||||
The protocol load (bps)
|
||||
"""
|
||||
return self.load
|
||||
|
||||
def add_load(self, _load: int) -> None:
|
||||
"""
|
||||
Adds load to the protocol.
|
||||
|
||||
Args:
|
||||
_load: The load to add
|
||||
"""
|
||||
self.load += _load
|
||||
|
||||
def clear_load(self) -> None:
|
||||
"""Clears the load on this protocol."""
|
||||
self.load = 0
|
||||
@@ -1,28 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SoftwareState
|
||||
|
||||
|
||||
class Service(object):
|
||||
"""Service class."""
|
||||
|
||||
def __init__(self, name: str, port: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Initialise a service.
|
||||
|
||||
:param name: The service name.
|
||||
:param port: The service port.
|
||||
:param software_state: The service SoftwareState.
|
||||
"""
|
||||
self.name: str = name
|
||||
self.port: str = port
|
||||
self.software_state: SoftwareState = software_state
|
||||
self.patching_count: int = 0
|
||||
|
||||
def reduce_patching_count(self) -> None:
|
||||
"""Reduces the patching count for the service."""
|
||||
self.patching_count -= 1
|
||||
if self.patching_count <= 0:
|
||||
self.patching_count = 0
|
||||
self.software_state = SoftwareState.GOOD
|
||||
726
src/primaite/config/_package_data/example_config.yaml
Normal file
726
src/primaite/config/_package_data/example_config.yaml
Normal file
@@ -0,0 +1,726 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 20
|
||||
n_eval_steps: 128
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: GATERLAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -1,166 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.5
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.6
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH3
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.7
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '8'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '2'
|
||||
- item_type: LINK
|
||||
id: '9'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '3'
|
||||
- item_type: GREEN_IER
|
||||
id: '13'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '3'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 60
|
||||
end_step: 100
|
||||
load: 1000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '2'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 80
|
||||
end_step: 80
|
||||
targetNodeId: '2'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: ACL_RULE
|
||||
id: '17'
|
||||
permission: ALLOW
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
@@ -1,366 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: PC3
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.13
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: PC4
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: IDS
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: SWITCH2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: LOP1
|
||||
node_class: SERVICE
|
||||
node_type: LOP
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '10'
|
||||
name: SERVER2
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.20.15
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: link4
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: link5
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: link6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: link7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: link8
|
||||
bandwidth: 1000000000
|
||||
source: '8'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '19'
|
||||
name: link9
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '9'
|
||||
- item_type: LINK
|
||||
id: '20'
|
||||
name: link10
|
||||
bandwidth: 1000000000
|
||||
source: '7'
|
||||
destination: '10'
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '3'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 128
|
||||
load: 100000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '10'
|
||||
mission_criticality: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '25'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '26'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '27'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '28'
|
||||
permission: ALLOW
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '29'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.14
|
||||
destination: 192.168.10.13
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '30'
|
||||
permission: DENY
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '31'
|
||||
permission: DENY
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '32'
|
||||
permission: DENY
|
||||
source: 192.168.10.13
|
||||
destination: 192.168.20.15
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: DENY
|
||||
source: 192.168.20.14
|
||||
destination: 192.168.10.14
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 8
|
||||
- item_type: RED_POL
|
||||
id: '34'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '35'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '2'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '36'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_IER
|
||||
id: '37'
|
||||
start_step: 30
|
||||
end_step: 128
|
||||
load: 440000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '38'
|
||||
start_step: 30
|
||||
end_step: 30
|
||||
targetNodeId: '9'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -1,164 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: PC1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.2
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: PC2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.3
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SERVER1
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.4
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '5'
|
||||
name: link1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '6'
|
||||
name: link2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '7'
|
||||
name: link3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: GREEN_IER
|
||||
id: '8'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '9'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '4'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '10'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '4'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '11'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.2
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '12'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.3
|
||||
destination: 192.168.1.4
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '13'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.4
|
||||
destination: 192.168.1.3
|
||||
protocol: TCP
|
||||
port: 80
|
||||
position: 2
|
||||
- item_type: RED_POL
|
||||
id: '14'
|
||||
start_step: 20
|
||||
end_step: 20
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '15'
|
||||
start_step: 30
|
||||
end_step: 256
|
||||
load: 10000000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '4'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '16'
|
||||
start_step: 40
|
||||
end_step: 40
|
||||
targetNodeId: '4'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
@@ -1,546 +0,0 @@
|
||||
- item_type: PORTS
|
||||
ports_list:
|
||||
- port: '80'
|
||||
- port: '1433'
|
||||
- port: '53'
|
||||
- item_type: SERVICES
|
||||
service_list:
|
||||
- name: TCP
|
||||
- name: TCP_SQL
|
||||
- name: UDP
|
||||
- item_type: NODE
|
||||
node_id: '1'
|
||||
name: CLIENT_1
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.11
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '2'
|
||||
name: CLIENT_2
|
||||
node_class: SERVICE
|
||||
node_type: COMPUTER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '3'
|
||||
name: SWITCH_1
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.10.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '4'
|
||||
name: SECURITY_SUITE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '5'
|
||||
name: MANAGEMENT_CONSOLE
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.1.12
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '6'
|
||||
name: SWITCH_2
|
||||
node_class: ACTIVE
|
||||
node_type: SWITCH
|
||||
priority: P2
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.1
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '7'
|
||||
name: WEB_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.10
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '8'
|
||||
name: DATABASE_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.14
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- name: TCP_SQL
|
||||
port: '1433'
|
||||
state: GOOD
|
||||
- name: UDP
|
||||
port: '53'
|
||||
state: GOOD
|
||||
- item_type: NODE
|
||||
node_id: '9'
|
||||
name: BACKUP_SERVER
|
||||
node_class: SERVICE
|
||||
node_type: SERVER
|
||||
priority: P5
|
||||
hardware_state: 'ON'
|
||||
ip_address: 192.168.2.16
|
||||
software_state: GOOD
|
||||
file_system_state: GOOD
|
||||
services:
|
||||
- name: TCP
|
||||
port: '80'
|
||||
state: GOOD
|
||||
- item_type: LINK
|
||||
id: '10'
|
||||
name: LINK_1
|
||||
bandwidth: 1000000000
|
||||
source: '1'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '11'
|
||||
name: LINK_2
|
||||
bandwidth: 1000000000
|
||||
source: '2'
|
||||
destination: '3'
|
||||
- item_type: LINK
|
||||
id: '12'
|
||||
name: LINK_3
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '4'
|
||||
- item_type: LINK
|
||||
id: '13'
|
||||
name: LINK_4
|
||||
bandwidth: 1000000000
|
||||
source: '3'
|
||||
destination: '5'
|
||||
- item_type: LINK
|
||||
id: '14'
|
||||
name: LINK_5
|
||||
bandwidth: 1000000000
|
||||
source: '4'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '15'
|
||||
name: LINK_6
|
||||
bandwidth: 1000000000
|
||||
source: '5'
|
||||
destination: '6'
|
||||
- item_type: LINK
|
||||
id: '16'
|
||||
name: LINK_7
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '7'
|
||||
- item_type: LINK
|
||||
id: '17'
|
||||
name: LINK_8
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '8'
|
||||
- item_type: LINK
|
||||
id: '18'
|
||||
name: LINK_9
|
||||
bandwidth: 1000000000
|
||||
source: '6'
|
||||
destination: '9'
|
||||
- item_type: GREEN_IER
|
||||
id: '19'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '20'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '1'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '21'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '22'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 10000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '2'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '23'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '7'
|
||||
destination: '8'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '24'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 100000
|
||||
protocol: TCP_SQL
|
||||
port: '1433'
|
||||
source: '8'
|
||||
destination: '7'
|
||||
mission_criticality: 5
|
||||
- item_type: GREEN_IER
|
||||
id: '25'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '1'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '26'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 50000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '2'
|
||||
destination: '9'
|
||||
mission_criticality: 2
|
||||
- item_type: GREEN_IER
|
||||
id: '27'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '7'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '28'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '7'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '29'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '8'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '30'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '8'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '31'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '5'
|
||||
destination: '9'
|
||||
mission_criticality: 1
|
||||
- item_type: GREEN_IER
|
||||
id: '32'
|
||||
start_step: 1
|
||||
end_step: 256
|
||||
load: 5000
|
||||
protocol: TCP
|
||||
port: '80'
|
||||
source: '9'
|
||||
destination: '5'
|
||||
mission_criticality: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '33'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 0
|
||||
- item_type: ACL_RULE
|
||||
id: '34'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 1
|
||||
- item_type: ACL_RULE
|
||||
id: '35'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 2
|
||||
- item_type: ACL_RULE
|
||||
id: '36'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 3
|
||||
- item_type: ACL_RULE
|
||||
id: '37'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.11
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 4
|
||||
- item_type: ACL_RULE
|
||||
id: '38'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.10.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 5
|
||||
- item_type: ACL_RULE
|
||||
id: '39'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 6
|
||||
- item_type: ACL_RULE
|
||||
id: '40'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 7
|
||||
- item_type: ACL_RULE
|
||||
id: '41'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.11
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 8
|
||||
- item_type: ACL_RULE
|
||||
id: '42'
|
||||
permission: ALLOW
|
||||
source: 192.168.10.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 9
|
||||
- item_type: ACL_RULE
|
||||
id: '43'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.10
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 10
|
||||
- item_type: ACL_RULE
|
||||
id: '44'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.14
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 11
|
||||
- item_type: ACL_RULE
|
||||
id: '45'
|
||||
permission: ALLOW
|
||||
source: 192.168.1.12
|
||||
destination: 192.168.2.16
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 12
|
||||
- item_type: ACL_RULE
|
||||
id: '46'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.10
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 13
|
||||
- item_type: ACL_RULE
|
||||
id: '47'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.14
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 14
|
||||
- item_type: ACL_RULE
|
||||
id: '48'
|
||||
permission: ALLOW
|
||||
source: 192.168.2.16
|
||||
destination: 192.168.1.12
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 15
|
||||
- item_type: ACL_RULE
|
||||
id: '49'
|
||||
permission: DENY
|
||||
source: ANY
|
||||
destination: ANY
|
||||
protocol: ANY
|
||||
port: ANY
|
||||
position: 16
|
||||
- item_type: RED_POL
|
||||
id: '50'
|
||||
start_step: 50
|
||||
end_step: 50
|
||||
targetNodeId: '1'
|
||||
initiator: DIRECT
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_IER
|
||||
id: '51'
|
||||
start_step: 75
|
||||
end_step: 105
|
||||
load: 10000
|
||||
protocol: UDP
|
||||
port: '53'
|
||||
source: '1'
|
||||
destination: '8'
|
||||
mission_criticality: 0
|
||||
- item_type: RED_POL
|
||||
id: '52'
|
||||
start_step: 100
|
||||
end_step: 100
|
||||
targetNodeId: '8'
|
||||
initiator: IER
|
||||
type: SERVICE
|
||||
protocol: UDP
|
||||
state: COMPROMISED
|
||||
sourceNodeId: NA
|
||||
sourceNodeService: NA
|
||||
sourceNodeServiceState: NA
|
||||
- item_type: RED_POL
|
||||
id: '53'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: FILE
|
||||
protocol: NA
|
||||
state: CORRUPT
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '54'
|
||||
start_step: 105
|
||||
end_step: 105
|
||||
targetNodeId: '8'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP_SQL
|
||||
state: COMPROMISED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: UDP
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
- item_type: RED_POL
|
||||
id: '55'
|
||||
start_step: 125
|
||||
end_step: 125
|
||||
targetNodeId: '7'
|
||||
initiator: SERVICE
|
||||
type: SERVICE
|
||||
protocol: TCP
|
||||
state: OVERWHELMED
|
||||
sourceNodeId: '8'
|
||||
sourceNodeService: TCP_SQL
|
||||
sourceNodeServiceState: COMPROMISED
|
||||
@@ -1,168 +0,0 @@
|
||||
# Training Config File
|
||||
|
||||
# Sets which agent algorithm framework will be used.
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "CUSTOM" (Custom Agent)
|
||||
agent_framework: SB3
|
||||
|
||||
# Sets which deep learning framework will be used (by RLlib ONLY).
|
||||
# Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TF2
|
||||
|
||||
# Sets which Agent class will be used.
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
|
||||
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
|
||||
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
|
||||
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
|
||||
# "RANDOM" (primaite.agents.simple.RandomAgent)
|
||||
# "DUMMY" (primaite.agents.simple.DummyAgent)
|
||||
agent_identifier: PPO
|
||||
|
||||
# Sets whether Red Agent POL and IER is randomised.
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
random_red_agent: False
|
||||
|
||||
# The (integer) seed to be used in random number generation
|
||||
# Default is None (null)
|
||||
seed: null
|
||||
|
||||
# Set whether the agent evaluation will be deterministic instead of stochastic
|
||||
# Options are:
|
||||
# True
|
||||
# False
|
||||
deterministic: False
|
||||
|
||||
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
|
||||
# Options are:
|
||||
# "BASIC" (The current observation space only)
|
||||
# "FULL" (Full environment view with actions taken and reward feedback)
|
||||
hard_coded_agent_view: FULL
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# observation space
|
||||
observation_space:
|
||||
flatten: true
|
||||
components:
|
||||
- name: NODE_LINK_TABLE
|
||||
- name: NODE_STATUSES
|
||||
- name: LINK_TRAFFIC_LEVELS
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Number of episodes for training to run per session
|
||||
num_train_episodes: 10
|
||||
|
||||
# Number of time_steps for training per episode
|
||||
num_train_steps: 256
|
||||
|
||||
# Number of episodes for evaluation to run per session
|
||||
num_eval_episodes: 1
|
||||
|
||||
# Number of time_steps for evaluation per episode
|
||||
num_eval_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 10
|
||||
|
||||
# Time delay (milliseconds) between steps for CUSTOM agents.
|
||||
time_delay: 5
|
||||
|
||||
# Type of session to be run. Options are:
|
||||
# "TRAIN" (Trains an agent)
|
||||
# "EVAL" (Evaluates an agent)
|
||||
# "TRAIN_EVAL" (Trains then evaluates an agent)
|
||||
session_type: TRAIN_EVAL
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observation_space_high_value: 1000000000
|
||||
|
||||
# Implicit ACL firewall rule at end of ACL list to be the default action (ALLOW or DENY)
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 30
|
||||
|
||||
# The Stable Baselines3 learn/eval output verbosity level:
|
||||
# Options are:
|
||||
# "NONE" (No Output)
|
||||
# "INFO" (Info Messages (such as devices and wrappers used))
|
||||
# "DEBUG" (All Messages)
|
||||
sb3_output_verbose_level: NONE
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
off_should_be_on: -0.001
|
||||
off_should_be_resetting: -0.0005
|
||||
on_should_be_off: -0.0002
|
||||
on_should_be_resetting: -0.0005
|
||||
resetting_should_be_on: -0.0005
|
||||
resetting_should_be_off: -0.0002
|
||||
resetting: -0.0003
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: 0.0002
|
||||
good_should_be_compromised: 0.0005
|
||||
good_should_be_overwhelmed: 0.0005
|
||||
patching_should_be_good: -0.0005
|
||||
patching_should_be_compromised: 0.0002
|
||||
patching_should_be_overwhelmed: 0.0002
|
||||
patching: -0.0003
|
||||
compromised_should_be_good: -0.002
|
||||
compromised_should_be_patching: -0.002
|
||||
compromised_should_be_overwhelmed: -0.002
|
||||
compromised: -0.002
|
||||
overwhelmed_should_be_good: -0.002
|
||||
overwhelmed_should_be_patching: -0.002
|
||||
overwhelmed_should_be_compromised: -0.002
|
||||
overwhelmed: -0.002
|
||||
# Node File System State
|
||||
good_should_be_repairing: 0.0002
|
||||
good_should_be_restoring: 0.0002
|
||||
good_should_be_corrupt: 0.0005
|
||||
good_should_be_destroyed: 0.001
|
||||
repairing_should_be_good: -0.0005
|
||||
repairing_should_be_restoring: 0.0002
|
||||
repairing_should_be_corrupt: 0.0002
|
||||
repairing_should_be_destroyed: 0.0000
|
||||
repairing: -0.0003
|
||||
restoring_should_be_good: -0.001
|
||||
restoring_should_be_repairing: -0.0002
|
||||
restoring_should_be_corrupt: 0.0001
|
||||
restoring_should_be_destroyed: 0.0002
|
||||
restoring: -0.0006
|
||||
corrupt_should_be_good: -0.001
|
||||
corrupt_should_be_repairing: -0.001
|
||||
corrupt_should_be_restoring: -0.001
|
||||
corrupt_should_be_destroyed: 0.0002
|
||||
corrupt: -0.001
|
||||
destroyed_should_be_good: -0.002
|
||||
destroyed_should_be_repairing: -0.002
|
||||
destroyed_should_be_restoring: -0.002
|
||||
destroyed_should_be_corrupt: -0.002
|
||||
destroyed: -0.002
|
||||
scanning: -0.0002
|
||||
# IER status
|
||||
red_ier_running: -0.0005
|
||||
green_ier_blocked: -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
@@ -1,141 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, List, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
|
||||
|
||||
|
||||
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a legacy lay down config to the new format.
|
||||
|
||||
:param legacy_config: A legacy lay down config.
|
||||
"""
|
||||
field_conversion_map = {
|
||||
"itemType": "item_type",
|
||||
"portsList": "ports_list",
|
||||
"serviceList": "service_list",
|
||||
"baseType": "node_class",
|
||||
"nodeType": "node_type",
|
||||
"hardwareState": "hardware_state",
|
||||
"softwareState": "software_state",
|
||||
"startStep": "start_step",
|
||||
"endStep": "end_step",
|
||||
"fileSystemState": "file_system_state",
|
||||
"ipAddress": "ip_address",
|
||||
"missionCriticality": "mission_criticality",
|
||||
}
|
||||
new_config = []
|
||||
for item in legacy_config:
|
||||
if "itemType" in item:
|
||||
if item["itemType"] in ["ACTIONS", "STEPS"]:
|
||||
continue
|
||||
new_dict = {}
|
||||
for key in item.keys():
|
||||
conversion_key = field_conversion_map.get(key)
|
||||
if key == "id" and "itemType" in item:
|
||||
if item["itemType"] == "NODE":
|
||||
conversion_key = "node_id"
|
||||
if conversion_key:
|
||||
new_dict[conversion_key] = item[key]
|
||||
else:
|
||||
new_dict[key] = item[key]
|
||||
new_config.append(new_dict)
|
||||
return new_config
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
|
||||
"""
|
||||
Read in a lay down config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise False.
|
||||
:return: The lay down config as a dict.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading lay down config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_lay_down_config(config)
|
||||
except KeyError:
|
||||
msg = (
|
||||
f"Failed to convert lay down config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
return config
|
||||
msg = f"Cannot load the lay down config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_1_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def ddos_basic_two_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_2_DDOS_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_2_DDOS_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def dos_very_basic_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_3_DOS_very_basic.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_3_DOS_very_basic.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def data_manipulation_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_5_data_manipulation.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_LAY_DOWN / "lay_down_config_5_data_manipulation.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
45
src/primaite/config/load.py
Normal file
45
src/primaite/config/load.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Final, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_CFG: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config"
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path]) -> Dict:
|
||||
"""
|
||||
Read a YAML file and return the contents as a dictionary.
|
||||
|
||||
:param file_path: Path to the YAML file.
|
||||
:type file_path: Union[str, Path]
|
||||
:return: Config dictionary
|
||||
:rtype: Dict
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
_LOGGER.error(f"File does not exist: {file_path}")
|
||||
raise FileNotFoundError(f"File does not exist: {file_path}")
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loaded config from {file_path}")
|
||||
return config
|
||||
|
||||
|
||||
def example_config_path() -> Path:
|
||||
"""
|
||||
Get the path to the example config.
|
||||
|
||||
:return: Path to the example config.
|
||||
:rtype: Path
|
||||
"""
|
||||
path = _EXAMPLE_CFG / "example_config.yaml"
|
||||
if not path.exists():
|
||||
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
return path
|
||||
@@ -1,438 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.common.enums import (
|
||||
ActionType,
|
||||
AgentFramework,
|
||||
AgentIdentifier,
|
||||
DeepLearningFramework,
|
||||
HardCodedAgentView,
|
||||
RulePermissionType,
|
||||
SB3OutputVerboseLevel,
|
||||
SessionType,
|
||||
)
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "training"
|
||||
|
||||
|
||||
def main_training_config_path() -> Path:
|
||||
"""
|
||||
The path to the example training_config_main.yaml file.
|
||||
|
||||
:return: The file path.
|
||||
"""
|
||||
path = _EXAMPLE_TRAINING / "training_config_main.yaml"
|
||||
if not path.exists():
|
||||
msg = "Example config not found. Please run 'primaite setup'"
|
||||
_LOGGER.critical(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The AgentFramework"
|
||||
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework"
|
||||
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO
|
||||
"The AgentIdentifier"
|
||||
|
||||
hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL
|
||||
"The view the deterministic hard-coded agent has of the environment"
|
||||
|
||||
random_red_agent: bool = False
|
||||
"Creates Random Red Agent Attacks"
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use"
|
||||
|
||||
num_train_episodes: int = 10
|
||||
"The number of episodes to train over during an training session"
|
||||
|
||||
num_train_steps: int = 256
|
||||
"The number of steps in an episode during an training session"
|
||||
|
||||
num_eval_episodes: int = 1
|
||||
"The number of episodes to train over during an evaluation session"
|
||||
|
||||
num_eval_steps: int = 256
|
||||
"The number of steps in an episode during an evaluation session"
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes"
|
||||
|
||||
observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]})
|
||||
"The observation space config dict"
|
||||
|
||||
time_delay: int = 10
|
||||
"The delay between steps (ms). Applies to generic agents only"
|
||||
|
||||
# file
|
||||
session_type: SessionType = SessionType.TRAIN
|
||||
"The type of PrimAITE session to run"
|
||||
|
||||
load_agent: bool = False
|
||||
"Determine whether to load an agent from file"
|
||||
|
||||
agent_load_file: Optional[str] = None
|
||||
"File path and file name of agent if you're loading one in"
|
||||
|
||||
# Environment
|
||||
observation_space_high_value: int = 1000000000
|
||||
"The high value for the observation space"
|
||||
|
||||
sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE
|
||||
"Stable Baselines3 learn/eval output verbosity level"
|
||||
|
||||
implicit_acl_rule: RulePermissionType = RulePermissionType.DENY
|
||||
"ALLOW or DENY implicit firewall rule to go at the end of list of ACL list."
|
||||
|
||||
max_number_acl_rules: int = 30
|
||||
"Sets a limit for number of acl rules allowed in the list and environment."
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
all_ok: float = 0
|
||||
|
||||
# Node Hardware State
|
||||
off_should_be_on: float = -0.001
|
||||
off_should_be_resetting: float = -0.0005
|
||||
on_should_be_off: float = -0.0002
|
||||
on_should_be_resetting: float = -0.0005
|
||||
resetting_should_be_on: float = -0.0005
|
||||
resetting_should_be_off: float = -0.0002
|
||||
resetting: float = -0.0003
|
||||
|
||||
# Node Software or Service State
|
||||
good_should_be_patching: float = 0.0002
|
||||
good_should_be_compromised: float = 0.0005
|
||||
good_should_be_overwhelmed: float = 0.0005
|
||||
patching_should_be_good: float = -0.0005
|
||||
patching_should_be_compromised: float = 0.0002
|
||||
patching_should_be_overwhelmed: float = 0.0002
|
||||
patching: float = -0.0003
|
||||
compromised_should_be_good: float = -0.002
|
||||
compromised_should_be_patching: float = -0.002
|
||||
compromised_should_be_overwhelmed: float = -0.002
|
||||
compromised: float = -0.002
|
||||
overwhelmed_should_be_good: float = -0.002
|
||||
overwhelmed_should_be_patching: float = -0.002
|
||||
overwhelmed_should_be_compromised: float = -0.002
|
||||
overwhelmed: float = -0.002
|
||||
|
||||
# Node File System State
|
||||
good_should_be_repairing: float = 0.0002
|
||||
good_should_be_restoring: float = 0.0002
|
||||
good_should_be_corrupt: float = 0.0005
|
||||
good_should_be_destroyed: float = 0.001
|
||||
repairing_should_be_good: float = -0.0005
|
||||
repairing_should_be_restoring: float = 0.0002
|
||||
repairing_should_be_corrupt: float = 0.0002
|
||||
repairing_should_be_destroyed: float = 0.0000
|
||||
repairing: float = -0.0003
|
||||
restoring_should_be_good: float = -0.001
|
||||
restoring_should_be_repairing: float = -0.0002
|
||||
restoring_should_be_corrupt: float = 0.0001
|
||||
restoring_should_be_destroyed: float = 0.0002
|
||||
restoring: float = -0.0006
|
||||
corrupt_should_be_good: float = -0.001
|
||||
corrupt_should_be_repairing: float = -0.001
|
||||
corrupt_should_be_restoring: float = -0.001
|
||||
corrupt_should_be_destroyed: float = 0.0002
|
||||
corrupt: float = -0.001
|
||||
destroyed_should_be_good: float = -0.002
|
||||
destroyed_should_be_repairing: float = -0.002
|
||||
destroyed_should_be_restoring: float = -0.002
|
||||
destroyed_should_be_corrupt: float = -0.002
|
||||
destroyed: float = -0.002
|
||||
scanning: float = -0.0002
|
||||
|
||||
# IER status
|
||||
red_ier_running: float = -0.0005
|
||||
green_ier_blocked: float = -0.001
|
||||
|
||||
# Patching / Reset durations
|
||||
os_patching_duration: int = 5
|
||||
"The time taken to patch the OS"
|
||||
|
||||
node_reset_duration: int = 5
|
||||
"The time taken to reset a node (hardware)"
|
||||
|
||||
node_booting_duration: int = 3
|
||||
"The Time taken to turn on the node"
|
||||
|
||||
node_shutdown_duration: int = 2
|
||||
"The time taken to turn off the node"
|
||||
|
||||
service_patching_duration: int = 5
|
||||
"The time taken to patch a service"
|
||||
|
||||
file_system_repairing_limit: int = 5
|
||||
"The time take to repair the file system"
|
||||
|
||||
file_system_restoring_limit: int = 5
|
||||
"The time take to restore the file system"
|
||||
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system"
|
||||
|
||||
deterministic: bool = False
|
||||
"If true, the training will be deterministic"
|
||||
|
||||
seed: Optional[int] = None
|
||||
"The random number generator seed to be used while training the agent"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
|
||||
"""
|
||||
Create an instance of TrainingConfig from a dict.
|
||||
|
||||
:param config_dict: The training config dict.
|
||||
:return: The instance of TrainingConfig.
|
||||
"""
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"agent_identifier": AgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel,
|
||||
"hard_coded_agent_view": HardCodedAgentView,
|
||||
"implicit_acl_rule": RulePermissionType,
|
||||
}
|
||||
|
||||
# convert the string representation of enums into the actual enum values themselves?
|
||||
for key, value in field_enum_map.items():
|
||||
if key in config_dict:
|
||||
config_dict[key] = value[config_dict[key]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True) -> Dict:
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
|
||||
:param json_serializable: If True, Enums are converted to their
|
||||
string name.
|
||||
:return: The ``TrainingConfig`` as a dict.
|
||||
"""
|
||||
data = self.__dict__
|
||||
if json_serializable:
|
||||
data["agent_framework"] = self.agent_framework.name
|
||||
data["deep_learning_framework"] = self.deep_learning_framework.name
|
||||
data["agent_identifier"] = self.agent_identifier.name
|
||||
data["action_type"] = self.action_type.name
|
||||
data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name
|
||||
data["session_type"] = self.session_type.name
|
||||
data["hard_coded_agent_view"] = self.hard_coded_agent_view.name
|
||||
data["implicit_acl_rule"] = self.implicit_acl_rule.name
|
||||
|
||||
return data
|
||||
|
||||
def __str__(self) -> str:
|
||||
obs_str = ",".join([c["name"] for c in self.observation_space["components"]])
|
||||
tc = f"{self.agent_framework}, "
|
||||
# if self.agent_framework is AgentFramework.RLLIB:
|
||||
# tc += f"{self.deep_learning_framework}, "
|
||||
tc += f"{self.agent_identifier}, "
|
||||
if self.agent_identifier is AgentIdentifier.HARDCODED:
|
||||
tc += f"{self.hard_coded_agent_view}, "
|
||||
tc += f"{self.action_type}, "
|
||||
tc += f"observation_space={obs_str}, "
|
||||
if self.session_type is SessionType.TRAIN:
|
||||
tc += f"{self.num_train_episodes} episodes @ "
|
||||
tc += f"{self.num_train_steps} steps"
|
||||
elif self.session_type is SessionType.EVAL:
|
||||
tc += f"{self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
else:
|
||||
tc += f"Training: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
tc += f"Evaluation: {self.num_eval_episodes} episodes @ "
|
||||
tc += f"{self.num_eval_steps} steps"
|
||||
return tc
|
||||
|
||||
|
||||
def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig:
|
||||
"""
|
||||
Read in a training config yaml file.
|
||||
|
||||
:param file_path: The config file path.
|
||||
:param legacy_file: True if the config file is legacy format, otherwise
|
||||
False.
|
||||
:return: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
:raises ValueError: If the file_path does not exist.
|
||||
:raises TypeError: When the TrainingConfig object cannot be created
|
||||
using the values from the config file read from ``file_path``.
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
_LOGGER.debug(f"Loading training config file: {file_path}")
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_training_config_dict(config)
|
||||
|
||||
except KeyError as e:
|
||||
msg = (
|
||||
f"Failed to convert training config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise e
|
||||
try:
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}"
|
||||
_LOGGER.critical(msg, exc_info=True)
|
||||
raise e
|
||||
msg = f"Cannot load the training config as it does not exist: {file_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_train_steps: int = 256,
|
||||
num_eval_steps: int = 256,
|
||||
num_train_episodes: int = 10,
|
||||
num_eval_episodes: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_train_steps: The number of train steps to set as legacy training configs
|
||||
don't have num_train_steps values.
|
||||
:param num_eval_steps: The number of eval steps to set as legacy training configs
|
||||
don't have num_eval_steps values.
|
||||
:param num_train_episodes: The number of train episodes to set as legacy training configs
|
||||
don't have num_train_episodes values.
|
||||
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
|
||||
don't have num_eval_episodes values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"agent_framework": agent_framework.name,
|
||||
"agent_identifier": agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_train_steps": num_train_steps,
|
||||
"num_eval_steps": num_eval_steps,
|
||||
"num_train_episodes": num_train_episodes,
|
||||
"num_eval_episodes": num_eval_episodes,
|
||||
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
|
||||
}
|
||||
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}
|
||||
legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]]
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
if new_key:
|
||||
config_dict[new_key] = value
|
||||
return config_dict
|
||||
|
||||
|
||||
def _get_new_key_from_legacy(legacy_key: str) -> Optional[str]:
|
||||
"""
|
||||
Maps legacy training config keys to the new format keys.
|
||||
|
||||
:param legacy_key: A legacy training config key.
|
||||
:return: The mapped key.
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": None,
|
||||
"numEpisodes": "num_train_episodes",
|
||||
"numSteps": "num_train_steps",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
"sessionType": "session_type",
|
||||
"loadAgent": "load_agent",
|
||||
"agentLoadFile": "agent_load_file",
|
||||
"observationSpaceHighValue": "observation_space_high_value",
|
||||
"allOk": "all_ok",
|
||||
"offShouldBeOn": "off_should_be_on",
|
||||
"offShouldBeResetting": "off_should_be_resetting",
|
||||
"onShouldBeOff": "on_should_be_off",
|
||||
"onShouldBeResetting": "on_should_be_resetting",
|
||||
"resettingShouldBeOn": "resetting_should_be_on",
|
||||
"resettingShouldBeOff": "resetting_should_be_off",
|
||||
"resetting": "resetting",
|
||||
"goodShouldBePatching": "good_should_be_patching",
|
||||
"goodShouldBeCompromised": "good_should_be_compromised",
|
||||
"goodShouldBeOverwhelmed": "good_should_be_overwhelmed",
|
||||
"patchingShouldBeGood": "patching_should_be_good",
|
||||
"patchingShouldBeCompromised": "patching_should_be_compromised",
|
||||
"patchingShouldBeOverwhelmed": "patching_should_be_overwhelmed",
|
||||
"patching": "patching",
|
||||
"compromisedShouldBeGood": "compromised_should_be_good",
|
||||
"compromisedShouldBePatching": "compromised_should_be_patching",
|
||||
"compromisedShouldBeOverwhelmed": "compromised_should_be_overwhelmed",
|
||||
"compromised": "compromised",
|
||||
"overwhelmedShouldBeGood": "overwhelmed_should_be_good",
|
||||
"overwhelmedShouldBePatching": "overwhelmed_should_be_patching",
|
||||
"overwhelmedShouldBeCompromised": "overwhelmed_should_be_compromised",
|
||||
"overwhelmed": "overwhelmed",
|
||||
"goodShouldBeRepairing": "good_should_be_repairing",
|
||||
"goodShouldBeRestoring": "good_should_be_restoring",
|
||||
"goodShouldBeCorrupt": "good_should_be_corrupt",
|
||||
"goodShouldBeDestroyed": "good_should_be_destroyed",
|
||||
"repairingShouldBeGood": "repairing_should_be_good",
|
||||
"repairingShouldBeRestoring": "repairing_should_be_restoring",
|
||||
"repairingShouldBeCorrupt": "repairing_should_be_corrupt",
|
||||
"repairingShouldBeDestroyed": "repairing_should_be_destroyed",
|
||||
"repairing": "repairing",
|
||||
"restoringShouldBeGood": "restoring_should_be_good",
|
||||
"restoringShouldBeRepairing": "restoring_should_be_repairing",
|
||||
"restoringShouldBeCorrupt": "restoring_should_be_corrupt",
|
||||
"restoringShouldBeDestroyed": "restoring_should_be_destroyed",
|
||||
"restoring": "restoring",
|
||||
"corruptShouldBeGood": "corrupt_should_be_good",
|
||||
"corruptShouldBeRepairing": "corrupt_should_be_repairing",
|
||||
"corruptShouldBeRestoring": "corrupt_should_be_restoring",
|
||||
"corruptShouldBeDestroyed": "corrupt_should_be_destroyed",
|
||||
"corrupt": "corrupt",
|
||||
"destroyedShouldBeGood": "destroyed_should_be_good",
|
||||
"destroyedShouldBeRepairing": "destroyed_should_be_repairing",
|
||||
"destroyedShouldBeRestoring": "destroyed_should_be_restoring",
|
||||
"destroyedShouldBeCorrupt": "destroyed_should_be_corrupt",
|
||||
"destroyed": "destroyed",
|
||||
"scanning": "scanning",
|
||||
"redIerRunning": "red_ier_running",
|
||||
"greenIerBlocked": "green_ier_blocked",
|
||||
"osPatchingDuration": "os_patching_duration",
|
||||
"nodeResetDuration": "node_reset_duration",
|
||||
"nodeBootingDuration": "node_booting_duration",
|
||||
"nodeShutdownDuration": "node_shutdown_duration",
|
||||
"servicePatchingDuration": "service_patching_duration",
|
||||
"fileSystemRepairingLimit": "file_system_repairing_limit",
|
||||
"fileSystemRestoringLimit": "file_system_restoring_limit",
|
||||
"fileSystemScanningLimit": "file_system_scanning_limit",
|
||||
}
|
||||
return key_mapping[legacy_key]
|
||||
@@ -1,15 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Utility to generate plots of sessions metrics after PrimAITE."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PlotlyTemplate(Enum):
|
||||
"""The built-in plotly templates."""
|
||||
|
||||
PLOTLY = "plotly"
|
||||
PLOTLY_WHITE = "plotly_white"
|
||||
PLOTLY_DARK = "plotly_dark"
|
||||
GGPLOT2 = "ggplot2"
|
||||
SEABORN = "seaborn"
|
||||
SIMPLE_WHITE = "simple_white"
|
||||
NONE = "none"
|
||||
@@ -1,73 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import polars as pl
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
def get_plotly_config() -> Dict:
|
||||
"""Get the plotly config from primaite_config.yaml."""
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config["session"]["outputs"]["plots"]
|
||||
|
||||
|
||||
def plot_av_reward_per_episode(
|
||||
av_reward_per_episode_csv: Union[str, Path],
|
||||
title: Optional[str] = None,
|
||||
subtitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""
|
||||
Plot the average reward per episode from a csv session output.
|
||||
|
||||
:param av_reward_per_episode_csv: The average reward per episode csv
|
||||
file path.
|
||||
:param title: The plot title. This is optional.
|
||||
:param subtitle: The plot subtitle. This is optional.
|
||||
:return: The plot as an instance of ``plotly.graph_objs._figure.Figure``.
|
||||
"""
|
||||
df = pl.read_csv(av_reward_per_episode_csv)
|
||||
|
||||
if title:
|
||||
if subtitle:
|
||||
title = f"{title} <br>{subtitle}</sup>"
|
||||
else:
|
||||
if subtitle:
|
||||
title = subtitle
|
||||
|
||||
config = get_plotly_config()
|
||||
layout = go.Layout(
|
||||
autosize=config["size"]["auto_size"],
|
||||
width=config["size"]["width"],
|
||||
height=config["size"]["height"],
|
||||
)
|
||||
# Create the line graph with a colored line
|
||||
fig = go.Figure(layout=layout)
|
||||
fig.update_layout(template=config["template"])
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=df["Episode"],
|
||||
y=df["Average Reward"],
|
||||
mode="lines",
|
||||
name="Mean Reward per Episode",
|
||||
)
|
||||
)
|
||||
|
||||
# Set the layout of the graph
|
||||
fig.update_layout(
|
||||
xaxis={
|
||||
"title": "Episode",
|
||||
"type": "linear",
|
||||
"rangeslider": {"visible": config["range_slider"]},
|
||||
},
|
||||
yaxis={"title": "Average Reward"},
|
||||
title=title,
|
||||
showlegend=False,
|
||||
)
|
||||
|
||||
return fig
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
|
||||
@@ -1,735 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.common.enums import FileSystemState, HardwareState, RulePermissionType, SoftwareState
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
# This dependency is only needed for type hints,
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: "Primaite") -> None:
|
||||
"""
|
||||
Initialise observation component.
|
||||
|
||||
:param env: Primaite training environment.
|
||||
:type env: Primaite
|
||||
"""
|
||||
_LOGGER.info(f"Initialising {self} observation component")
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
self.structure: List[str]
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def update(self) -> None:
|
||||
"""Update the observation based on the current state of the environment."""
|
||||
self.current_observation = NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def generate_structure(self) -> List[str]:
|
||||
"""Return a list of labels for the components of the flattened observation space."""
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NodeLinkTable(AbstractObservationComponent):
    """
    Table with nodes and links as rows and hardware/software status as cols.

    This will create the observation space formatted as a table of integers.
    There is one row per node, followed by one row per link.
    The number of columns is 4 plus one per service. They are:

    * node/link ID
    * node hardware status / 0 for links
    * node operating system status (if active/service) / 0 for links
    * node file system status (active/service only) / 0 for links
    * node service1 status / traffic load from that service for links
    * node service2 status / traffic load from that service for links
    * ...
    * node serviceN status / traffic load from that service for links

    For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
    ``(12, 7)``
    """

    # Number of non-service columns: ID, hardware, OS, file system.
    _FIXED_PARAMETERS: int = 4
    # Upper bound for the Box space; link traffic loads can be large integers.
    _MAX_VAL: int = 1_000_000_000
    _DATA_TYPE: type = np.int64

    def __init__(self, env: "Primaite") -> None:
        """
        Initialise a NodeLinkTable observation space component.

        :param env: Training environment.
        :type env: Primaite
        """
        super().__init__(env)

        # 1. Define the shape of your observation space component
        num_items = self.env.num_links + self.env.num_nodes
        num_columns = self.env.num_services + self._FIXED_PARAMETERS
        observation_shape = (num_items, num_columns)

        # 2. Create Observation space
        self.space = spaces.Box(
            low=0,
            high=self._MAX_VAL,
            shape=observation_shape,
            dtype=self._DATA_TYPE,
        )

        # 3. Initialise Observation with zeroes
        self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)

        self.structure = self.generate_structure()

    def update(self) -> None:
        """
        Update the observation based on current environment state.

        The structure of the observation space is described in :class:`.NodeLinkTable`
        """
        item_index = 0

        # Do nodes first
        for node in self.env.nodes.values():
            row = self.current_observation[item_index]
            row[0] = int(node.node_id)
            row[1] = node.hardware_state.value
            if isinstance(node, (ActiveNode, ServiceNode)):
                row[2] = node.software_state.value
                row[3] = node.file_system_state_observed.value
            else:
                # Passive nodes have no OS or file system state.
                row[2] = 0
                row[3] = 0
            service_index = 4
            if isinstance(node, ServiceNode):
                for service in self.env.services_list:
                    if node.has_service(service):
                        row[service_index] = node.get_service_state(service).value
                    else:
                        row[service_index] = 0
                    service_index += 1
            else:
                # Not a service node - all service columns are 0
                for _ in self.env.services_list:
                    row[service_index] = 0
                    service_index += 1
            item_index += 1

        # Now do links
        for link in self.env.links.values():
            row = self.current_observation[item_index]
            row[0] = int(link.get_id())
            # Hardware / OS / file-system columns do not apply to links.
            row[1] = 0
            row[2] = 0
            row[3] = 0
            # Service columns carry the per-protocol traffic load instead.
            for protocol_index, protocol in enumerate(link.get_protocol_list()):
                row[protocol_index + 4] = protocol.get_load()
            item_index += 1

    def generate_structure(self) -> List[str]:
        """Return a list of labels for the components of the flattened observation space."""
        structure = []

        for node in self.env.nodes.values():
            node_id = node.node_id
            node_labels = [
                f"node_{node_id}_id",
                f"node_{node_id}_hardware_status",
                f"node_{node_id}_os_status",
                f"node_{node_id}_fs_status",
            ]
            for serv in self.env.services_list:
                node_labels.append(f"node_{node_id}_service_{serv}_status")

            structure.extend(node_labels)

        for link in self.env.links.values():
            link_id = link.id
            link_labels = [
                f"link_{link_id}_id",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
                f"link_{link_id}_n/a",
            ]
            for serv in self.env.services_list:
                link_labels.append(f"link_{link_id}_service_{serv}_load")

            structure.extend(link_labels)
        return structure
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
    """
    Flat list of nodes' hardware, OS, file system, and service states.

    The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states,
    represented by integers.
    Each node has 3 elements plus 1 per service. It will have the following structure:
    .. code-block::

        [
            node1 hardware state,
            node1 OS state,
            node1 file system state,
            node1 service1 state,
            node1 service2 state,
            node1 serviceN state (one for each service),
            node2 hardware state,
            node2 OS state,
            node2 file system state,
            node2 service1 state,
            node2 service2 state,
            node2 serviceN state (one for each service),
            ...
        ]
    """

    _DATA_TYPE: type = np.int64

    def __init__(self, env: "Primaite") -> None:
        """
        Initialise a NodeStatuses observation component.

        :param env: Training environment.
        :type env: Primaite
        """
        super().__init__(env)

        # 1. Define the shape of your observation space component.
        # Each enum gets one extra slot ("+ 1") to encode a NONE/not-applicable state.
        node_shape = [
            len(HardwareState) + 1,
            len(SoftwareState) + 1,
            len(FileSystemState) + 1,
        ]
        services_shape = [len(SoftwareState) + 1] * self.env.num_services
        node_shape = node_shape + services_shape

        shape = node_shape * self.env.num_nodes
        # 2. Create Observation space
        self.space = spaces.MultiDiscrete(shape)

        # 3. Initialise observation with zeroes
        self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
        self.structure = self.generate_structure()

    def update(self) -> None:
        """
        Update the observation based on current environment state.

        The structure of the observation space is described in :class:`.NodeStatuses`
        """
        obs = []
        for node in self.env.nodes.values():
            hardware_state = node.hardware_state.value
            # Default everything else to 0 (NONE); only active/service nodes override.
            software_state = 0
            file_system_state = 0
            service_states = [0] * self.env.num_services

            if isinstance(node, ActiveNode):
                software_state = node.software_state.value
                file_system_state = node.file_system_state_observed.value

            if isinstance(node, ServiceNode):
                for i, service in enumerate(self.env.services_list):
                    if node.has_service(service):
                        service_states[i] = node.get_service_state(service).value
            obs.extend(
                [
                    hardware_state,
                    software_state,
                    file_system_state,
                    *service_states,
                ]
            )
        self.current_observation[:] = obs

    def generate_structure(self) -> List[str]:
        """Return a list of labels for the components of the flattened observation space."""
        services = self.env.services_list

        structure = []

        for node in self.env.nodes.values():
            node_id = node.node_id
            structure.append(f"node_{node_id}_hardware_state_NONE")
            for state in HardwareState:
                structure.append(f"node_{node_id}_hardware_state_{state.name}")
            structure.append(f"node_{node_id}_software_state_NONE")
            for state in SoftwareState:
                structure.append(f"node_{node_id}_software_state_{state.name}")
            structure.append(f"node_{node_id}_file_system_state_NONE")
            for state in FileSystemState:
                structure.append(f"node_{node_id}_file_system_state_{state.name}")
            for service in services:
                structure.append(f"node_{node_id}_service_{service}_state_NONE")
                for state in SoftwareState:
                    structure.append(f"node_{node_id}_service_{service}_state_{state.name}")
        return structure
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
    """
    Flat list of traffic levels encoded into banded categories.

    For each link, total traffic or traffic per service is encoded into a categorical value.
    For example, if ``quantisation_levels=5``, the traffic levels represent these values:

    * 0 = No traffic (0% of bandwidth)
    * 1 = Low traffic (0%-33% of bandwidth)
    * 2 = Medium traffic (33%-66% of bandwidth)
    * 3 = High traffic (66%-100% of bandwidth)
    * 4 = Max traffic (100% of bandwidth)

    .. note::
        The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
        Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.

    """

    _DATA_TYPE: type = np.int64

    def __init__(
        self,
        env: "Primaite",
        combine_service_traffic: bool = False,
        quantisation_levels: int = 5,
    ) -> None:
        """
        Initialise a LinkTrafficLevels observation component.

        :param env: The environment that forms the basis of the observations
        :type env: Primaite
        :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
            defaults to False
        :type combine_service_traffic: bool, optional
        :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical
            value, defaults to 5
        :type quantisation_levels: int, optional
        """
        # Fewer than 3 levels cannot represent 0%, 100%, and at least one band between.
        if quantisation_levels < 3:
            _msg = (
                f"quantisation_levels must be 3 or more because the lowest and highest levels are "
                f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
                f"Resetting to default value (5)"
            )
            _LOGGER.warning(_msg)
            quantisation_levels = 5

        super().__init__(env)

        self._combine_service_traffic: bool = combine_service_traffic
        self._quantisation_levels: int = quantisation_levels
        self._entries_per_link: int = 1

        if not self._combine_service_traffic:
            self._entries_per_link = self.env.num_services

        # 1. Define the shape of your observation space component
        shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link

        # 2. Create Observation space
        self.space = spaces.MultiDiscrete(shape)

        # 3. Initialise observation with zeroes
        self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)

        self.structure = self.generate_structure()

    def update(self) -> None:
        """
        Update the observation based on current environment state.

        The structure of the observation space is described in :class:`.LinkTrafficLevels`
        """
        obs = []
        for link in self.env.links.values():
            bandwidth = link.bandwidth
            if self._combine_service_traffic:
                loads = [link.get_current_load()]
            else:
                loads = [protocol.get_load() for protocol in link.protocol_list]

            for load in loads:
                if load <= 0:
                    # Lowest band is reserved for exactly no traffic.
                    traffic_level = 0
                elif load >= bandwidth:
                    # Highest band is reserved for a saturated link.
                    traffic_level = self._quantisation_levels - 1
                else:
                    # Split (0%, 100%) utilisation evenly across the remaining bands.
                    traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1

                obs.append(int(traffic_level))

        self.current_observation[:] = obs

    def generate_structure(self) -> List[str]:
        """Return a list of labels for the components of the flattened observation space."""
        structure = []
        for link in self.env.links.values():
            link_id = link.id
            if self._combine_service_traffic:
                protocols = ["overall"]
            else:
                protocols = [protocol.name for protocol in link.protocol_list]

            for p in protocols:
                for i in range(self._quantisation_levels):
                    structure.append(f"link_{link_id}_{p}_traffic_level_{i}")
        return structure
|
||||
|
||||
|
||||
class AccessControlList(AbstractObservationComponent):
    """Flat list of all the Access Control Rules in the Access Control List.

    The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states,
    represented by integers.

    Each ACL Rule has 6 elements. It will have the following structure:
    .. code-block::
        [
            acl_rule1 permission,
            acl_rule1 source_ip,
            acl_rule1 dest_ip,
            acl_rule1 protocol,
            acl_rule1 port,
            acl_rule1 position,
            acl_rule2 permission,
            acl_rule2 source_ip,
            acl_rule2 dest_ip,
            acl_rule2 protocol,
            acl_rule2 port,
            acl_rule2 position,
            ...
        ]

    Terms (for ACL Observation Space):
    [0, 1, 2] - Permission (0 = NA, 1 = DENY, 2 = ALLOW)
    [0, num nodes] - Source IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
    [0, num nodes] - Dest IP (0 = NA, 1 = any, then 2 -> x resolving to Node IDs)
    [0, num services] - Protocol (0 = NA, 1 = any, then 2 -> x resolving to protocol)
    [0, num ports] - Port (0 = NA, 1 = any, then 2 -> x resolving to port)
    [0, max acl rules - 1] - Position (0 = NA, 1 = first index, then 2 -> x index resolving to acl rule in acl list)

    NOTE: NA is Non-Applicable - this means the ACL Rule in the list is a NoneType and NOT an ACLRule object.
    """

    _DATA_TYPE: type = np.int64

    def __init__(self, env: "Primaite") -> None:
        """
        Initialise an AccessControlList observation component.

        :param env: The environment that forms the basis of the observations
        :type env: Primaite
        """
        super().__init__(env)

        # 1. Define the shape of your observation space component
        # The NA and ANY types means that there are 2 extra items for Nodes, Services and Ports.
        # Number of ACL rules incremented by 1 for positions starting at index 0.
        acl_shape = [
            len(RulePermissionType),
            len(env.nodes) + 2,
            len(env.nodes) + 2,
            len(env.services_list) + 2,
            len(env.ports_list) + 2,
            env.max_number_acl_rules,
        ]
        shape = acl_shape * self.env.max_number_acl_rules

        # 2. Create Observation space
        self.space = spaces.MultiDiscrete(shape)

        # 3. Initialise observation with zeroes
        self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)

        self.structure = self.generate_structure()

    def update(self) -> None:
        """Update the observation based on current environment state.

        The structure of the observation space is described in :class:`.AccessControlList`
        """
        obs = []

        for index, acl_rule in enumerate(self.env.acl.acl):
            if isinstance(acl_rule, ACLRule):
                permission = acl_rule.permission
                source_ip = acl_rule.source_ip
                dest_ip = acl_rule.dest_ip
                protocol = acl_rule.protocol
                port = acl_rule.port
                position = index
                # Map each ACL attribute from what it was to an integer to fit the observation space
                source_ip_int = None
                dest_ip_int = None
                if permission == RulePermissionType.DENY:
                    permission_int = 1
                else:
                    permission_int = 2
                if source_ip == "ANY":
                    source_ip_int = 1
                else:
                    # Map Node ID (+ 1) to source IP address.
                    # Index of Nodes start at 1 so + 1 is needed so NA can be added.
                    for node in self.env.nodes.values():
                        if isinstance(node, (ActiveNode, ServiceNode)) and node.ip_address == source_ip:
                            source_ip_int = int(node.node_id) + 1
                            break
                if dest_ip == "ANY":
                    dest_ip_int = 1
                else:
                    # Map Node ID (+ 1) to dest IP address, same encoding as source IP.
                    # (break added for consistency with the source-IP scan: take the first match)
                    for node in self.env.nodes.values():
                        if isinstance(node, (ActiveNode, ServiceNode)) and node.ip_address == dest_ip:
                            dest_ip_int = int(node.node_id) + 1
                            break
                if protocol == "ANY":
                    protocol_int = 1
                else:
                    # Index of protocols and ports start from 0 so + 2 is needed to add NA and ANY
                    try:
                        protocol_int = self.env.services_list.index(protocol) + 2
                    except ValueError:
                        # list.index raises ValueError (not AttributeError) when the item is missing
                        _LOGGER.info(f"Service {protocol} could not be found")
                        protocol_int = None
                if port == "ANY":
                    port_int = 1
                else:
                    if port in self.env.ports_list:
                        port_int = self.env.ports_list.index(port) + 2
                    else:
                        _LOGGER.info(f"Port {port} could not be found.")
                        port_int = None
                # Add to current obs
                obs.extend(
                    [
                        permission_int,
                        source_ip_int,
                        dest_ip_int,
                        protocol_int,
                        port_int,
                        position,
                    ]
                )

            else:
                # The Nothing or NA representation of 'NONE' ACL rules
                obs.extend([0, 0, 0, 0, 0, 0])

        self.current_observation[:] = obs

    def generate_structure(self) -> List[str]:
        """Return a list of labels for the components of the flattened observation space."""
        structure = []
        # Use enumerate rather than list.index(): index() always returns the FIRST
        # occurrence, so every empty (None) rule slot would be labelled with the
        # same id as the first empty slot.
        for acl_rule_id, _acl_rule in enumerate(self.env.acl.acl):
            for permission in RulePermissionType:
                structure.append(f"acl_rule_{acl_rule_id}_permission_{permission.name}")

            structure.append(f"acl_rule_{acl_rule_id}_source_ip_ANY")
            for node in self.env.nodes.keys():
                structure.append(f"acl_rule_{acl_rule_id}_source_ip_{node}")

            structure.append(f"acl_rule_{acl_rule_id}_dest_ip_ANY")
            for node in self.env.nodes.keys():
                structure.append(f"acl_rule_{acl_rule_id}_dest_ip_{node}")

            structure.append(f"acl_rule_{acl_rule_id}_service_ANY")
            for service in self.env.services_list:
                structure.append(f"acl_rule_{acl_rule_id}_service_{service}")

            structure.append(f"acl_rule_{acl_rule_id}_port_ANY")
            for port in self.env.ports_list:
                structure.append(f"acl_rule_{acl_rule_id}_port_{port}")

        return structure
|
||||
|
||||
|
||||
class ObservationsHandler:
    """
    Component-based observation space handler.

    This allows users to configure observation spaces by mixing and matching components. Each component can also define
    further parameters to make them more flexible.
    """

    # Maps config-file component names to the classes that implement them.
    _REGISTRY: Final[Dict[str, type]] = {
        "NODE_LINK_TABLE": NodeLinkTable,
        "NODE_STATUSES": NodeStatuses,
        "LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
        "ACCESS_CONTROL_LIST": AccessControlList,
    }

    def __init__(self) -> None:
        """Initialise the observation handler."""
        self.registered_obs_components: List[AbstractObservationComponent] = []

        # internal observation space (unflattened version of space if flatten=True)
        self._space: spaces.Space
        # flattened version of the observation space
        self._flat_space: spaces.Space

        self._observation: Union[Tuple[np.ndarray], np.ndarray]
        # used for transactions and when flatten=true
        self._flat_observation: np.ndarray

    def update_obs(self) -> None:
        """Fetch fresh information about the environment."""
        current_obs = []
        for obs in self.registered_obs_components:
            obs.update()
            current_obs.append(obs.current_observation)

        # With a single component the observation is that component's array;
        # with several it becomes a tuple matching the Tuple space from update_space().
        if len(current_obs) == 1:
            self._observation = current_obs[0]
        else:
            self._observation = tuple(current_obs)
        self._flat_observation = spaces.flatten(self._space, self._observation)

    def register(self, obs_component: AbstractObservationComponent) -> None:
        """
        Add a component for this handler to track.

        :param obs_component: The component to add.
        :type obs_component: AbstractObservationComponent
        """
        self.registered_obs_components.append(obs_component)
        self.update_space()

    def deregister(self, obs_component: AbstractObservationComponent) -> None:
        """
        Remove a component from this handler.

        :param obs_component: Which component to remove. It must exist within this object's
            ``registered_obs_components`` attribute.
        :type obs_component: AbstractObservationComponent
        """
        self.registered_obs_components.remove(obs_component)
        self.update_space()

    def update_space(self) -> None:
        """Rebuild the handler's composite observation space from its components."""
        component_spaces = []
        for obs_comp in self.registered_obs_components:
            component_spaces.append(obs_comp.space)

        # if there are multiple components, build a composite tuple space
        if len(component_spaces) == 1:
            self._space = component_spaces[0]
        else:
            self._space = spaces.Tuple(component_spaces)
        if len(component_spaces) > 0:
            self._flat_space = spaces.flatten_space(self._space)
        else:
            # no components registered yet: use an empty Box as a placeholder
            self._flat_space = spaces.Box(0, 1, (0,))

    @property
    def space(self) -> spaces.Space:
        """Observation space, return the flattened version if flatten is True."""
        if len(self.registered_obs_components) > 1:
            return self._flat_space
        else:
            return self._space

    @property
    def current_observation(self) -> Union[np.ndarray, Tuple[np.ndarray]]:
        """Current observation, return the flattened version if flatten is True."""
        if len(self.registered_obs_components) > 1:
            return self._flat_observation
        else:
            return self._observation

    @classmethod
    def from_config(cls, env: "Primaite", obs_space_config: dict) -> "ObservationsHandler":
        """
        Parse a config dictionary, return a new observation handler populated with new observation component objects.

        The expected format for the config dictionary is:

        .. code-block:: python

            config = {
                components: [
                    {
                        "name": "<COMPONENT1_NAME>"
                    },
                    {
                        "name": "<COMPONENT2_NAME>"
                        "options": {"opt1": val1, "opt2": val2}
                    },
                    {
                        ...
                    },
                ]
            }

        :return: Observation handler
        :rtype: primaite.environment.observations.ObservationsHandler
        """
        # Instantiate the handler
        handler = cls()

        for component_cfg in obs_space_config["components"]:
            # Figure out which class can instantiate the desired component
            comp_type = component_cfg["name"]
            comp_class = cls._REGISTRY[comp_type]

            # Create the component with options from the YAML
            options = component_cfg.get("options") or {}
            component = comp_class(env, **options)

            handler.register(component)

        handler.update_obs()
        return handler

    def describe_structure(self) -> List[str]:
        """
        Create a list of names for the features of the obs space.

        The order of labels follows the flattened version of the space.
        """
        # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have
        # to fake it. each component has to just hard-code the expected label order after flattening...

        labels = []
        for obs_comp in self.registered_obs_components:
            labels.extend(obs_comp.structure)

        return labels
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,386 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements reward function."""
|
||||
from logging import Logger
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER: Logger = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
    initial_nodes: Dict[str, NodeUnion],
    final_nodes: Dict[str, NodeUnion],
    reference_nodes: Dict[str, NodeUnion],
    green_iers: Dict[str, "IER"],
    green_iers_reference: Dict[str, "IER"],
    red_iers: Dict[str, "IER"],
    step_count: int,
    config_values: "TrainingConfig",
) -> float:
    """
    Compares the states of the initial and final nodes/links to get a reward.

    Args:
        initial_nodes: The nodes before red and blue agents take effect
        final_nodes: The nodes after red and blue agents take effect
        reference_nodes: The nodes if there had been no red or blue effect
        green_iers: The green IERs (should be running)
        green_iers_reference: The green IERs in the reference environment (no red or blue effect)
        red_iers: Should be stopped (ideally) by the blue agent
        step_count: current step
        config_values: Config values
    """
    reward_value: float = 0.0

    # For each node, compare hardware state, SoftwareState, service states
    for node_key, final_node in final_nodes.items():
        initial_node = initial_nodes[node_key]
        reference_node = reference_nodes[node_key]

        # Hardware State
        reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values)

        # Software State
        if isinstance(final_node, (ActiveNode, ServiceNode)):
            reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values)

        # Service State
        if isinstance(final_node, ServiceNode):
            reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values)

        # File System State
        if isinstance(final_node, ActiveNode):
            reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values)

    # Go through each red IER - penalise if it is running
    for ier_value in red_iers.values():
        if ier_value.get_start_step() <= step_count <= ier_value.get_end_step():
            if ier_value.get_is_running():
                reward_value += config_values.red_ier_running

    # Go through each green IER - penalise if it's not running (weighted)
    # but only if it's supposed to be running (it's running in reference)
    for ier_key, ier_value in green_iers.items():
        reference_ier = green_iers_reference[ier_key]
        if ier_value.get_start_step() <= step_count <= ier_value.get_end_step():
            reference_blocked = not reference_ier.get_is_running()
            live_blocked = not ier_value.get_is_running()
            ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality()

            if live_blocked and not reference_blocked:
                reward_value += ier_reward
            elif live_blocked and reference_blocked:
                _LOGGER.debug(
                    (
                        f"IER {ier_key} is blocked in the reference and live environments. "
                        f"Penalty of {ier_reward} was NOT applied."
                    )
                )
            elif not live_blocked and reference_blocked:
                _LOGGER.debug(
                    (
                        f"IER {ier_key} is blocked in the reference env but not in the live one. "
                        f"Penalty of {ier_reward} was NOT applied."
                    )
                )
    return reward_value
|
||||
|
||||
|
||||
def score_node_operating_state(
    final_node: NodeUnion, initial_node: NodeUnion, reference_node: NodeUnion, config_values: "TrainingConfig"
) -> float:
    """
    Calculates score relating to the hardware state of a node.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (unused, kept for a consistent interface)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values
    """
    final_state = final_node.hardware_state
    reference_state = reference_node.hardware_state

    if final_state == reference_state:
        # All is well - we're no different from the reference situation
        return float(config_values.all_ok)

    # We're different from the reference situation: look up the score for the
    # (expected, actual) hardware state pair. Pairs not listed score 0, which
    # matches the original chain's fall-through behaviour. The original also
    # had a (RESETTING, RESETTING) branch, but that case is unreachable here
    # because identical states are already handled above.
    penalties = {
        (HardwareState.ON, HardwareState.OFF): config_values.off_should_be_on,
        (HardwareState.ON, HardwareState.RESETTING): config_values.resetting_should_be_on,
        (HardwareState.OFF, HardwareState.ON): config_values.on_should_be_off,
        (HardwareState.OFF, HardwareState.RESETTING): config_values.resetting_should_be_off,
        (HardwareState.RESETTING, HardwareState.ON): config_values.on_should_be_resetting,
        (HardwareState.RESETTING, HardwareState.OFF): config_values.off_should_be_resetting,
    }
    return float(penalties.get((reference_state, final_state), 0.0))
|
||||
|
||||
|
||||
def score_node_os_state(
    final_node: Union[ActiveNode, ServiceNode],
    initial_node: Union[ActiveNode, ServiceNode],
    reference_node: Union[ActiveNode, ServiceNode],
    config_values: "TrainingConfig",
) -> float:
    """
    Calculates score relating to the Software State of a node.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (unused, kept for a consistent interface)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values
    """
    final_state = final_node.software_state
    reference_state = reference_node.software_state

    if final_state == reference_state:
        # All is well - we're no different from the reference situation
        return float(config_values.all_ok)

    # We're different from the reference situation: look up the score for the
    # (expected, actual) SoftwareState pair. Pairs not listed score 0, matching
    # the original chain's fall-through behaviour (this includes states not
    # modelled here). The original (PATCHING, PATCHING) and
    # (COMPROMISED, COMPROMISED) branches are unreachable because identical
    # states are already handled above.
    penalties = {
        (SoftwareState.GOOD, SoftwareState.PATCHING): config_values.patching_should_be_good,
        (SoftwareState.GOOD, SoftwareState.COMPROMISED): config_values.compromised_should_be_good,
        (SoftwareState.PATCHING, SoftwareState.GOOD): config_values.good_should_be_patching,
        (SoftwareState.PATCHING, SoftwareState.COMPROMISED): config_values.compromised_should_be_patching,
        (SoftwareState.COMPROMISED, SoftwareState.GOOD): config_values.good_should_be_compromised,
        (SoftwareState.COMPROMISED, SoftwareState.PATCHING): config_values.patching_should_be_compromised,
    }
    return float(penalties.get((reference_state, final_state), 0.0))
|
||||
|
||||
|
||||
def score_node_service_state(
    final_node: ServiceNode, initial_node: ServiceNode, reference_node: ServiceNode, config_values: "TrainingConfig"
) -> float:
    """
    Calculates score relating to the service state(s) of a node.

    Each service on the final node is compared against the same service on the reference node:
    matching states score ``all_ok``, mismatches score the value configured for that
    (reference state, final state) pair.

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not currently used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values holding the per-transition scores
    """
    score: float = 0.0

    # Name of the config_values attribute to add for each (reference state, final state)
    # mismatch. Equal-state pairs never reach this table because the equality check below
    # awards `all_ok` first (this also covers the previously-dead `patching`/`compromised`/
    # `overwhelmed` same-state branches). Unknown pairs score nothing, as before.
    mismatch_attrs: Dict[tuple, str] = {
        (SoftwareState.GOOD, SoftwareState.PATCHING): "patching_should_be_good",
        (SoftwareState.GOOD, SoftwareState.COMPROMISED): "compromised_should_be_good",
        (SoftwareState.GOOD, SoftwareState.OVERWHELMED): "overwhelmed_should_be_good",
        (SoftwareState.PATCHING, SoftwareState.GOOD): "good_should_be_patching",
        (SoftwareState.PATCHING, SoftwareState.COMPROMISED): "compromised_should_be_patching",
        (SoftwareState.PATCHING, SoftwareState.OVERWHELMED): "overwhelmed_should_be_patching",
        (SoftwareState.COMPROMISED, SoftwareState.GOOD): "good_should_be_compromised",
        (SoftwareState.COMPROMISED, SoftwareState.PATCHING): "patching_should_be_compromised",
        (SoftwareState.COMPROMISED, SoftwareState.OVERWHELMED): "overwhelmed_should_be_compromised",
        (SoftwareState.OVERWHELMED, SoftwareState.GOOD): "good_should_be_overwhelmed",
        (SoftwareState.OVERWHELMED, SoftwareState.PATCHING): "patching_should_be_overwhelmed",
        (SoftwareState.OVERWHELMED, SoftwareState.COMPROMISED): "compromised_should_be_overwhelmed",
    }

    final_node_services: Dict[str, Service] = final_node.services
    reference_node_services: Dict[str, Service] = reference_node.services

    # Iterate the final node's services; the redundant second lookup of final_service that
    # shadowed the loop variable has been removed.
    for service_key, final_service in final_node_services.items():
        reference_service = reference_node_services[service_key]

        if final_service.software_state == reference_service.software_state:
            # All is well - we're no different from the reference situation
            score += config_values.all_ok
        else:
            # We're different from the reference situation - look up the transition score.
            attr = mismatch_attrs.get((reference_service.software_state, final_service.software_state))
            if attr is not None:
                score += getattr(config_values, attr)

    return score
|
||||
|
||||
|
||||
def score_node_file_system(
    final_node: Union[ActiveNode, ServiceNode],
    initial_node: Union[ActiveNode, ServiceNode],
    reference_node: Union[ActiveNode, ServiceNode],
    config_values: "TrainingConfig",
) -> float:
    """
    Calculates score relating to the file system state of a node.

    The final node's file-system state and scanning state are each compared against the
    reference node: a match scores ``all_ok``, a mismatch scores the configured value for
    that (reference state, final state) pair (or ``scanning`` for the scanning flag).

    Args:
        final_node: The node after red and blue agents take effect
        initial_node: The node before red and blue agents take effect (not currently used)
        reference_node: The node if there had been no red or blue effect
        config_values: Config values holding the per-transition scores
    """
    score: float = 0.0

    # Name of the config_values attribute to add for each (reference state, final state)
    # mismatch. Equal-state pairs never reach this table because the equality check below
    # awards `all_ok` first. Unknown pairs score nothing, as before.
    mismatch_attrs: Dict[tuple, str] = {
        (FileSystemState.GOOD, FileSystemState.REPAIRING): "repairing_should_be_good",
        (FileSystemState.GOOD, FileSystemState.RESTORING): "restoring_should_be_good",
        (FileSystemState.GOOD, FileSystemState.CORRUPT): "corrupt_should_be_good",
        (FileSystemState.GOOD, FileSystemState.DESTROYED): "destroyed_should_be_good",
        (FileSystemState.REPAIRING, FileSystemState.GOOD): "good_should_be_repairing",
        (FileSystemState.REPAIRING, FileSystemState.RESTORING): "restoring_should_be_repairing",
        (FileSystemState.REPAIRING, FileSystemState.CORRUPT): "corrupt_should_be_repairing",
        (FileSystemState.REPAIRING, FileSystemState.DESTROYED): "destroyed_should_be_repairing",
        (FileSystemState.RESTORING, FileSystemState.GOOD): "good_should_be_restoring",
        (FileSystemState.RESTORING, FileSystemState.REPAIRING): "repairing_should_be_restoring",
        (FileSystemState.RESTORING, FileSystemState.CORRUPT): "corrupt_should_be_restoring",
        (FileSystemState.RESTORING, FileSystemState.DESTROYED): "destroyed_should_be_restoring",
        (FileSystemState.CORRUPT, FileSystemState.GOOD): "good_should_be_corrupt",
        (FileSystemState.CORRUPT, FileSystemState.REPAIRING): "repairing_should_be_corrupt",
        (FileSystemState.CORRUPT, FileSystemState.RESTORING): "restoring_should_be_corrupt",
        (FileSystemState.CORRUPT, FileSystemState.DESTROYED): "destroyed_should_be_corrupt",
        (FileSystemState.DESTROYED, FileSystemState.GOOD): "good_should_be_destroyed",
        (FileSystemState.DESTROYED, FileSystemState.REPAIRING): "repairing_should_be_destroyed",
        (FileSystemState.DESTROYED, FileSystemState.RESTORING): "restoring_should_be_destroyed",
        (FileSystemState.DESTROYED, FileSystemState.CORRUPT): "corrupt_should_be_destroyed",
    }

    final_node_file_system_state = final_node.file_system_state_actual
    reference_node_file_system_state = reference_node.file_system_state_actual

    # File System State
    if final_node_file_system_state == reference_node_file_system_state:
        # All is well - we're no different from the reference situation
        score += config_values.all_ok
    else:
        # We're different from the reference situation - look up the transition score.
        attr = mismatch_attrs.get((reference_node_file_system_state, final_node_file_system_state))
        if attr is not None:
            score += getattr(config_values, attr)

    # Scanning State
    if final_node.file_system_scanning == reference_node.file_system_scanning:
        # All is well - we're no different from the reference situation
        score += config_values.all_ok
    else:
        # Scanning the file system incurs a penalty (as it slows down systems).
        # NOTE: as in the original code, ANY mismatch is penalised, including when only
        # the reference node is scanning.
        score += config_values.scanning

    return score
|
||||
@@ -5,12 +5,6 @@ class PrimaiteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RLlibAgentError(PrimaiteError):
    """Raised when there is a generic error with an RLlib agent that is specific to PrimAITE."""

    pass
|
||||
|
||||
|
||||
class NetworkError(PrimaiteError):
|
||||
"""Raised when an error occurs at the network level."""
|
||||
|
||||
|
||||
1
src/primaite/game/__init__.py
Normal file
1
src/primaite/game/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""PrimAITE Game Layer."""
|
||||
31
src/primaite/game/agent/GATE_agents.py
Normal file
31
src/primaite/game/agent/GATE_agents.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from src.primaite.game.agent.actions import ActionManager
|
||||
from src.primaite.game.agent.interface import AbstractGATEAgent, ObsType
|
||||
from src.primaite.game.agent.observations import ObservationSpace
|
||||
from src.primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
    """
    Agent whose actions are supplied externally via GATE.

    The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual
    agents, because when we are supporting MARL, the actions from multiple agents will have to be batched.
    For example, MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it
    will reply with the actions for those agents.
    """

    def __init__(
        self,
        agent_name: str | None,
        action_space: ActionManager | None,
        observation_space: ObservationSpace | None,
        reward_function: RewardFunction | None,
    ) -> None:
        """
        Init method for GATERLAgent.

        :param agent_name: Name of this agent.
        :param action_space: The agent's action manager.
        :param observation_space: The agent's observation space.
        :param reward_function: The agent's reward function.
        """
        super().__init__(agent_name, action_space, observation_space, reward_function)
        # Declared but not assigned here: set externally (by the session) before get_action is called,
        # otherwise get_action raises AttributeError.
        self.most_recent_action: ActType

    def get_action(self, obs: ObsType, reward: float | None = None) -> Tuple[str, Dict]:
        """
        Return the most recent action pushed to this agent.

        NOTE(review): the declared return type is ``Tuple[str, Dict]`` but the value returned is
        ``self.most_recent_action`` (an ``ActType``) — confirm the intended contract.

        :param obs: Current observation (ignored; the action is chosen externally).
        :param reward: Most recent reward (ignored).
        """
        return self.most_recent_action
|
||||
0
src/primaite/game/agent/__init__.py
Normal file
0
src/primaite/game/agent/__init__.py
Normal file
866
src/primaite/game/agent/actions.py
Normal file
866
src/primaite/game/agent/actions.py
Normal file
@@ -0,0 +1,866 @@
|
||||
"""
|
||||
This module contains the ActionManager class which belongs to the Agent class.
|
||||
|
||||
An agent's action space is made up of a collection of actions. Each action is an instance of a subclass of
|
||||
AbstractAction. The ActionManager is responsible for:
|
||||
1. Creating the action space from a list of action types.
|
||||
2. Converting an integer action choice into a specific action and parameter choice.
|
||||
3. Converting an action and parameter choice into a request which can be ingested by the PrimAITE simulation. This
|
||||
ensures that requests conform to the simulator's request format.
|
||||
"""
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from src.primaite.simulator.sim_container import Simulation
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractAction(ABC):
    """Base class for actions."""

    @abstractmethod
    def __init__(self, manager: "ActionManager", **kwargs) -> None:
        """
        Init method for action.

        All action init functions should accept **kwargs as a way of ignoring extra arguments.

        Since many parameters are defined for the action space as a whole (such as max files per folder, max services
        per node), we need to pass those options to every action that gets created. To prevent verbosity, these
        parameters are just broadcasted to all actions and the actions can pay attention to the ones that apply.

        :param manager: The ActionManager which created this action.
        """
        # Human-readable action identifier used for printing, logging, and reporting.
        self.name: str = ""
        # Number of options for each parameter of this action. The keys of this dict must align
        # with the keyword args of the form_request method.
        self.shape: Dict[str, int] = {}
        # Reference to the ActionManager which created this action; used to access the session and
        # simulation objects. (Annotation quoted: ActionManager is defined later in this module and
        # attribute annotations are evaluated at runtime.)
        self.manager: "ActionManager" = manager

    @abstractmethod
    def form_request(self) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        return []
|
||||
|
||||
|
||||
class DoNothingAction(AbstractAction):
    """Action which does nothing. This is here to allow agents to be idle if they choose to."""

    def __init__(self, manager: "ActionManager", **kwargs) -> None:
        """Init method for DoNothingAction; extra kwargs are ignored."""
        super().__init__(manager=manager)
        self.name = "DONOTHING"
        # This action takes no real parameters, so its gymnasium shape is technically Discrete(1)
        # (a choice between one option). A single-option 'dummy' parameter is declared purely so
        # the ActionManager can enumerate every action's possibilities uniformly.
        self.shape: Dict[str, int] = {"dummy": 1}

    def form_request(self, **kwargs) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        return ["do_nothing"]
|
||||
|
||||
|
||||
class NodeServiceAbstractAction(AbstractAction):
    """
    Base class for service actions.

    Any action which applies to a service and uses node_id and service_id as its only two parameters can inherit from
    this base class.
    """

    @abstractmethod
    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method; concrete subclasses are expected to set ``self.verb``."""
        super().__init__(manager=manager)
        self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
        self.verb: str

    def form_request(self, node_id: int, service_id: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        target_node = self.manager.get_node_uuid_by_idx(node_id)
        target_service = self.manager.get_service_uuid_by_idx(node_id, service_id)
        # If either index doesn't map to a real object in the simulation, fall back to a no-op.
        if target_node is None or target_service is None:
            return ["do_nothing"]
        return ["network", "node", target_node, "services", target_service, self.verb]
|
||||
|
||||
|
||||
class NodeServiceScanAction(NodeServiceAbstractAction):
    """Action which scans a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the scan service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "scan"


class NodeServiceStopAction(NodeServiceAbstractAction):
    """Action which stops a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the stop service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "stop"


class NodeServiceStartAction(NodeServiceAbstractAction):
    """Action which starts a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the start service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "start"


class NodeServicePauseAction(NodeServiceAbstractAction):
    """Action which pauses a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the pause service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "pause"


class NodeServiceResumeAction(NodeServiceAbstractAction):
    """Action which resumes a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the resume service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "resume"


class NodeServiceRestartAction(NodeServiceAbstractAction):
    """Action which restarts a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the restart service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "restart"


class NodeServiceDisableAction(NodeServiceAbstractAction):
    """Action which disables a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the disable service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "disable"


class NodeServiceEnableAction(NodeServiceAbstractAction):
    """Action which enables a service."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
        """Init method for the enable service action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
        self.verb: str = "enable"
|
||||
|
||||
|
||||
class NodeFolderAbstractAction(AbstractAction):
    """
    Base class for folder actions.

    Any action which applies to a folder and uses node_id and folder_id as its only two parameters can inherit from
    this base class.
    """

    @abstractmethod
    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
        """Init method; concrete subclasses are expected to set ``self.verb``."""
        super().__init__(manager=manager)
        self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
        self.verb: str

    def form_request(self, node_id: int, folder_id: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        target_node = self.manager.get_node_uuid_by_idx(node_id)
        target_folder = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
        # If either index doesn't map to a real object in the simulation, fall back to a no-op.
        if target_node is None or target_folder is None:
            return ["do_nothing"]
        return ["network", "node", target_node, "file_system", "folder", target_folder, self.verb]
|
||||
|
||||
|
||||
class NodeFolderScanAction(NodeFolderAbstractAction):
    """Action which scans a folder."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
        """Init method for the scan folder action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
        self.verb = "scan"


class NodeFolderCheckhashAction(NodeFolderAbstractAction):
    """Action which checks the hash of a folder."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
        """Init method for the check-hash folder action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
        self.verb = "checkhash"


class NodeFolderRepairAction(NodeFolderAbstractAction):
    """Action which repairs a folder."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
        """Init method for the repair folder action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
        self.verb = "repair"


class NodeFolderRestoreAction(NodeFolderAbstractAction):
    """Action which restores a folder."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
        """Init method for the restore folder action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
        self.verb = "restore"
|
||||
|
||||
|
||||
class NodeFileAbstractAction(AbstractAction):
    """
    Abstract base class for file actions.

    Any action which applies to a file and uses node_id, folder_id, and file_id as its only three parameters can
    inherit from this base class.
    """

    @abstractmethod
    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method; concrete subclasses are expected to set ``self.verb``."""
        super().__init__(manager=manager)
        self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
        self.verb: str

    def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        target_node = self.manager.get_node_uuid_by_idx(node_id)
        target_folder = self.manager.get_folder_uuid_by_idx(node_idx=node_id, folder_idx=folder_id)
        target_file = self.manager.get_file_uuid_by_idx(node_idx=node_id, folder_idx=folder_id, file_idx=file_id)
        # If any index doesn't map to a real object in the simulation, fall back to a no-op.
        if target_node is None or target_folder is None or target_file is None:
            return ["do_nothing"]
        return [
            "network",
            "node",
            target_node,
            "file_system",
            "folder",
            target_folder,
            "files",
            target_file,
            self.verb,
        ]
|
||||
|
||||
|
||||
class NodeFileScanAction(NodeFileAbstractAction):
    """Action which scans a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the scan file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "scan"


class NodeFileCheckhashAction(NodeFileAbstractAction):
    """Action which checks the hash of a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the check-hash file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "checkhash"


class NodeFileDeleteAction(NodeFileAbstractAction):
    """Action which deletes a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the delete file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "delete"


class NodeFileRepairAction(NodeFileAbstractAction):
    """Action which repairs a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the repair file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "repair"


class NodeFileRestoreAction(NodeFileAbstractAction):
    """Action which restores a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the restore file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "restore"


class NodeFileCorruptAction(NodeFileAbstractAction):
    """Action which corrupts a file."""

    def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
        """Init method for the corrupt file action."""
        super().__init__(manager=manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
        self.verb = "corrupt"
|
||||
|
||||
|
||||
class NodeAbstractAction(AbstractAction):
    """
    Abstract base class for node actions.

    Any action which applies to a node and uses node_id as its only parameter can inherit from this base class.
    """

    @abstractmethod
    def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
        """Init method; concrete subclasses are expected to set ``self.verb``."""
        super().__init__(manager=manager)
        self.shape: Dict[str, int] = {"node_id": num_nodes}
        self.verb: str

    def form_request(self, node_id: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        node_uuid = self.manager.get_node_uuid_by_idx(node_id)
        # Guard against an index that maps to no node, matching the behaviour of the other
        # abstract actions (previously a request containing None could be produced here).
        if node_uuid is None:
            return ["do_nothing"]
        return ["network", "node", node_uuid, self.verb]
|
||||
|
||||
|
||||
class NodeOSScanAction(NodeAbstractAction):
    """Action which scans a node's OS."""

    def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
        """Init method for the OS scan action."""
        super().__init__(manager=manager, num_nodes=num_nodes)
        self.verb: str = "scan"


class NodeShutdownAction(NodeAbstractAction):
    """Action which shuts down a node."""

    def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
        """Init method for the shutdown action."""
        super().__init__(manager=manager, num_nodes=num_nodes)
        self.verb: str = "shutdown"


class NodeStartupAction(NodeAbstractAction):
    """Action which starts up a node."""

    def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
        """Init method for the startup action."""
        super().__init__(manager=manager, num_nodes=num_nodes)
        self.verb: str = "startup"


class NodeResetAction(NodeAbstractAction):
    """Action which resets a node."""

    def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
        """Init method for the reset action."""
        super().__init__(manager=manager, num_nodes=num_nodes)
        self.verb: str = "reset"
|
||||
|
||||
|
||||
class NetworkACLAddRuleAction(AbstractAction):
    """Action which adds a rule to a router's ACL."""

    def __init__(
        self,
        manager: "ActionManager",
        target_router_uuid: str,
        max_acl_rules: int,
        num_ips: int,
        num_ports: int,
        num_protocols: int,
        **kwargs,
    ) -> None:
        """Init method for NetworkACLAddRuleAction.

        :param manager: Reference to the ActionManager which created this action.
        :type manager: ActionManager
        :param target_router_uuid: UUID of the router to which the ACL rule should be added.
        :type target_router_uuid: str
        :param max_acl_rules: Maximum number of ACL rules that can be added to the router.
        :type max_acl_rules: int
        :param num_ips: Number of IP addresses in the simulation.
        :type num_ips: int
        :param num_ports: Number of ports in the simulation.
        :type num_ports: int
        :param num_protocols: Number of protocols in the simulation.
        :type num_protocols: int
        """
        super().__init__(manager=manager)
        num_permissions = 3  # UNUSED=0, ALLOW=1, DENY=2
        self.shape: Dict[str, int] = {
            "position": max_acl_rules,
            "permission": num_permissions,
            "source_ip_id": num_ips,
            "dest_ip_id": num_ips,
            "source_port_id": num_ports,
            "dest_port_id": num_ports,
            "protocol_id": num_protocols,
        }
        self.target_router_uuid: str = target_router_uuid

    def form_request(
        self,
        position: int,
        permission: int,
        source_ip_id: int,
        dest_ip_id: int,
        source_port_id: int,
        dest_port_id: int,
        protocol_id: int,
    ) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation.

        Parameter values of 0 mean UNUSED and 1 means ALL; concrete objects start at index 2,
        hence the ``- 2`` offsets below. Any unsupported combination resolves to a no-op request.
        """
        if permission == 0:
            return ["do_nothing"]  # UNUSED permission is not supported, just do nothing
        elif permission == 1:
            permission_str = "ALLOW"
        elif permission == 2:
            permission_str = "DENY"
        else:
            # BUGFIX: previously this branch fell through with permission_str unbound (NameError);
            # the message also claimed the expected range was 0-1 even though 2 (DENY) is valid.
            _LOGGER.warning(f"{self.__class__} received permission {permission}, expected a value in 0-2.")
            return ["do_nothing"]

        if protocol_id == 0:
            return ["do_nothing"]  # UNUSED protocol is not supported, just do nothing
        if protocol_id == 1:
            protocol = "ALL"
        else:
            # subtract 2 to account for UNUSED=0 and ALL=1.
            protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2)

        if source_ip_id in (0, 1):
            return ["do_nothing"]  # UNUSED/ALL source IP not supported, just do nothing
        # subtract 2 to account for UNUSED=0, and ALL=1
        src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2)

        if source_port_id == 0:
            # BUGFIX: UNUSED source port previously fell into the lookup below with index -2.
            return ["do_nothing"]
        if source_port_id == 1:
            src_port = "ALL"
        else:
            # subtract 2 to account for UNUSED=0, and ALL=1
            src_port = self.manager.get_port_by_idx(source_port_id - 2)

        if dest_ip_id in (0, 1):
            return ["do_nothing"]  # UNUSED/ALL destination IP not supported, just do nothing
        # BUGFIX: the -2 offset (documented here and applied to the source IP) was missing.
        dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2)

        if dest_port_id == 0:
            # BUGFIX: UNUSED destination port previously fell into the lookup below.
            return ["do_nothing"]
        if dest_port_id == 1:
            dst_port = "ALL"
        else:
            # BUGFIX: the -2 offset (documented here and applied to the source port) was missing.
            dst_port = self.manager.get_port_by_idx(dest_port_id - 2)

        return [
            "network",
            "node",
            self.target_router_uuid,
            "acl",
            "add_rule",
            permission_str,
            protocol,
            src_ip,
            src_port,
            dst_ip,
            dst_port,
            position,
        ]
|
||||
|
||||
|
||||
class NetworkACLRemoveRuleAction(AbstractAction):
    """Action which removes a rule from a router's ACL."""

    def __init__(self, manager: "ActionManager", target_router_uuid: str, max_acl_rules: int, **kwargs) -> None:
        """Init method for NetworkACLRemoveRuleAction.

        :param manager: Reference to the ActionManager which created this action.
        :type manager: ActionManager
        :param target_router_uuid: UUID of the router from which the ACL rule should be removed.
        :type target_router_uuid: str
        :param max_acl_rules: Maximum number of ACL rules that can be added to the router.
        :type max_acl_rules: int
        """
        super().__init__(manager=manager)
        # The only parameter is the rule slot to clear.
        self.shape: Dict[str, int] = {"position": max_acl_rules}
        self.target_router_uuid: str = target_router_uuid

    def form_request(self, position: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        return ["network", "node", self.target_router_uuid, "acl", "remove_rule", position]
|
||||
|
||||
|
||||
class NetworkNICAbstractAction(AbstractAction):
    """
    Abstract base class for NIC actions.

    Any action which applies to a NIC and uses node_id and nic_id as its only two parameters can inherit from this
    base class.
    """

    def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
        """Init method for NetworkNICAbstractAction.

        :param manager: Reference to the ActionManager which created this action.
        :type manager: ActionManager
        :param num_nodes: Number of nodes in the simulation.
        :type num_nodes: int
        :param max_nics_per_node: Maximum number of NICs per node.
        :type max_nics_per_node: int
        """
        super().__init__(manager=manager)
        self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
        self.verb: str

    def form_request(self, node_id: int, nic_id: int) -> List[str]:
        """Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
        target_node = self.manager.get_node_uuid_by_idx(node_idx=node_id)
        target_nic = self.manager.get_nic_uuid_by_idx(node_idx=node_id, nic_idx=nic_id)
        # If either index doesn't map to a real object in the simulation, fall back to a no-op.
        if target_node is None or target_nic is None:
            return ["do_nothing"]
        return ["network", "node", target_node, "nic", target_nic, self.verb]
|
||||
|
||||
|
||||
class NetworkNICEnableAction(NetworkNICAbstractAction):
    """Action which enables a NIC."""

    def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
        """Init method for NetworkNICEnableAction.

        :param manager: Reference to the ActionManager which created this action.
        :type manager: ActionManager
        :param num_nodes: Number of nodes in the simulation.
        :type num_nodes: int
        :param max_nics_per_node: Maximum number of NICs per node.
        :type max_nics_per_node: int
        """
        super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
        self.verb: str = "enable"
|
||||
|
||||
|
||||
class NetworkNICDisableAction(NetworkNICAbstractAction):
    """Action which disables a NIC."""

    def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
        """Init method for NetworkNICDisableAction.

        :param manager: Reference to the ActionManager which created this action.
        :type manager: ActionManager
        :param num_nodes: Number of nodes in the simulation.
        :type num_nodes: int
        :param max_nics_per_node: Maximum number of NICs per node.
        :type max_nics_per_node: int
        """
        super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
        self.verb: str = "disable"
|
||||
|
||||
|
||||
class ActionManager:
    """Class which manages the action space for an agent."""

    __act_class_identifiers: Dict[str, type] = {
        "DONOTHING": DoNothingAction,
        "NODE_SERVICE_SCAN": NodeServiceScanAction,
        "NODE_SERVICE_STOP": NodeServiceStopAction,
        "NODE_SERVICE_START": NodeServiceStartAction,
        "NODE_SERVICE_PAUSE": NodeServicePauseAction,
        "NODE_SERVICE_RESUME": NodeServiceResumeAction,
        "NODE_SERVICE_RESTART": NodeServiceRestartAction,
        "NODE_SERVICE_DISABLE": NodeServiceDisableAction,
        "NODE_SERVICE_ENABLE": NodeServiceEnableAction,
        "NODE_FILE_SCAN": NodeFileScanAction,
        "NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
        "NODE_FILE_DELETE": NodeFileDeleteAction,
        "NODE_FILE_REPAIR": NodeFileRepairAction,
        "NODE_FILE_RESTORE": NodeFileRestoreAction,
        "NODE_FILE_CORRUPT": NodeFileCorruptAction,
        "NODE_FOLDER_SCAN": NodeFolderScanAction,
        "NODE_FOLDER_CHECKHASH": NodeFolderCheckhashAction,
        "NODE_FOLDER_REPAIR": NodeFolderRepairAction,
        "NODE_FOLDER_RESTORE": NodeFolderRestoreAction,
        "NODE_OS_SCAN": NodeOSScanAction,
        "NODE_SHUTDOWN": NodeShutdownAction,
        "NODE_STARTUP": NodeStartupAction,
        "NODE_RESET": NodeResetAction,
        "NETWORK_ACL_ADDRULE": NetworkACLAddRuleAction,
        "NETWORK_ACL_REMOVERULE": NetworkACLRemoveRuleAction,
        "NETWORK_NIC_ENABLE": NetworkNICEnableAction,
        "NETWORK_NIC_DISABLE": NetworkNICDisableAction,
    }
    """Dictionary which maps action type strings to the corresponding action class."""

    def __init__(
        self,
        session: "PrimaiteSession",  # reference to session for looking up stuff
        actions: List[Dict],  # stores list of action specs available to agent (each has 'type' and 'options')
        node_uuids: List[str],  # allows mapping index to node
        max_folders_per_node: int = 2,  # allows calculating shape
        max_files_per_folder: int = 2,  # allows calculating shape
        max_services_per_node: int = 2,  # allows calculating shape
        max_nics_per_node: int = 8,  # allows calculating shape
        max_acl_rules: int = 10,  # allows calculating shape
        protocols: Optional[List[str]] = None,  # allow mapping index to protocol
        ports: Optional[List[str]] = None,  # allow mapping index to port
        ip_address_list: Optional[List[str]] = None,  # to allow us to map an index to an ip address.
        act_map: Optional[Dict[int, Dict]] = None,  # allows restricting set of possible actions
    ) -> None:
        """Init method for ActionManager.

        :param session: Reference to the session to which the agent belongs.
        :type session: PrimaiteSession
        :param actions: List of action specs which should be made available to the agent. Each spec is a dict with a
            ``type`` key and an optional ``options`` dict.
        :type actions: List[Dict]
        :param node_uuids: List of node UUIDs that this agent can act on.
        :type node_uuids: List[str]
        :param max_folders_per_node: Maximum number of folders per node. Used for calculating action shape.
        :type max_folders_per_node: int
        :param max_files_per_folder: Maximum number of files per folder. Used for calculating action shape.
        :type max_files_per_folder: int
        :param max_services_per_node: Maximum number of services per node. Used for calculating action shape.
        :type max_services_per_node: int
        :param max_nics_per_node: Maximum number of NICs per node. Used for calculating action shape.
        :type max_nics_per_node: int
        :param max_acl_rules: Maximum number of ACL rules per router. Used for calculating action shape.
        :type max_acl_rules: int
        :param protocols: List of protocols that are available in the simulation. Used for calculating action shape.
            Defaults to ["TCP", "UDP", "ICMP"] when not provided.
        :type protocols: Optional[List[str]]
        :param ports: List of ports that are available in the simulation. Used for calculating action shape.
            Defaults to ["HTTP", "DNS", "ARP", "FTP"] when not provided.
        :type ports: Optional[List[str]]
        :param ip_address_list: List of IP addresses that are known to this agent. Used for calculating action shape.
        :type ip_address_list: Optional[List[str]]
        :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions.
        :type act_map: Optional[Dict[int, Dict]]
        """
        self.session: "PrimaiteSession" = session
        self.sim: Simulation = self.session.simulation
        self.node_uuids: List[str] = node_uuids
        # Avoid mutable default arguments: fall back to the historic defaults when not provided.
        self.protocols: List[str] = ["TCP", "UDP", "ICMP"] if protocols is None else protocols
        self.ports: List[str] = ["HTTP", "DNS", "ARP", "FTP"] if ports is None else ports

        self.ip_address_list: List[str]
        if ip_address_list is not None:
            self.ip_address_list = ip_address_list
        else:
            # Derive the list of known IPs from the NICs of the nodes this agent can act on.
            self.ip_address_list = []
            for node_uuid in self.node_uuids:
                node_obj = self.sim.network.nodes[node_uuid]
                for nic_obj in node_obj.nics.values():
                    self.ip_address_list.append(nic_obj.ip_address)

        # action_args are settings which are applied to the action space as a whole.
        global_action_args = {
            "num_nodes": len(node_uuids),
            "num_folders": max_folders_per_node,
            "num_files": max_files_per_folder,
            "num_services": max_services_per_node,
            "num_nics": max_nics_per_node,
            "num_acl_rules": max_acl_rules,
            "num_protocols": len(self.protocols),
            # BUGFIX: this previously used len(self.protocols) (copy-paste error).
            "num_ports": len(self.ports),
            "num_ips": len(self.ip_address_list),
            "max_acl_rules": max_acl_rules,
            "max_nics_per_node": max_nics_per_node,
        }
        self.actions: Dict[str, AbstractAction] = {}
        for act_spec in actions:
            # each action is provided into the action space config like this:
            # - type: ACTION_TYPE
            #   options:
            #     option_1: value1
            #     option_2: value2
            # where `type` decides which AbstractAction subclass should be used
            # and `options` is an optional dict of options to pass to the init method of the action class
            act_type = act_spec.get("type")
            act_options = act_spec.get("options", {})
            self.actions[act_type] = self.__act_class_identifiers[act_type](self, **global_action_args, **act_options)

        self.action_map: Dict[int, Tuple[str, Dict]] = {}
        """
        Action mapping that converts an integer to a specific action and parameter choice.

        For example :
        {0: ("NODE_SERVICE_SCAN", {node_id:0, service_id:2})}
        """
        if act_map is None:
            self.action_map = self._enumerate_actions()
        else:
            self.action_map = {i: (a["action"], a["options"]) for i, a in act_map.items()}
        # make sure all numbers between 0 and N are represented as dict keys in action map
        assert all(i in self.action_map for i in range(len(self.action_map)))

    def _enumerate_actions(
        self,
    ) -> Dict[int, Tuple[str, Dict]]:
        """Generate a list of all the possible actions that could be taken.

        This enumerates all actions and all combinations of parameters you could choose for those actions. The output
        of this function is intended to populate the self.action_map parameter in the situation where the user provides
        a list of action types, but doesn't specify any subset of actions that should be made available to the agent.

        The enumeration relies on the Actions' `shape` attribute.

        :return: An action map that maps consecutive integers to a combination of Action type and parameter choices.
            An example output could be:
            {0: ("DONOTHING", {'dummy': 0}),
            1: ("NODE_OS_SCAN", {'node_id': 0}),
            2: ("NODE_OS_SCAN", {'node_id': 1}),
            3: ("NODE_FOLDER_SCAN", {'node_id': 0, 'folder_id': 0}),
            ... #etc...
            }
        :rtype: Dict[int, Tuple[str, Dict]]
        """
        all_action_possibilities = []
        for act_name, action in self.actions.items():
            param_names = list(action.shape.keys())
            # Cartesian product over each parameter's range of valid indices.
            possibilities = [range(n) for n in action.shape.values()]
            for combination in itertools.product(*possibilities):
                all_action_possibilities.append((act_name, dict(zip(param_names, combination))))

        return dict(enumerate(all_action_possibilities))

    def get_action(self, action: int) -> Tuple[str, Dict]:
        """Produce action in CAOS format.

        The agent chooses an action as an integer; this is converted into CAOS format, which is an action
        identifier followed by parameters stored in a dictionary.
        """
        act_identifier, act_options = self.action_map[action]
        return act_identifier, act_options

    def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
        """Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
        act_obj = self.actions[action_identifier]
        return act_obj.form_request(**action_options)

    @property
    def space(self) -> spaces.Space:
        """Return the gymnasium action space for this agent."""
        return spaces.Discrete(len(self.action_map))

    def get_node_uuid_by_idx(self, node_idx: int) -> str:
        """
        Get the node UUID corresponding to the given index.

        :param node_idx: The index of the node to retrieve.
        :type node_idx: int
        :return: The node UUID.
        :rtype: str
        """
        return self.node_uuids[node_idx]

    def get_folder_uuid_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]:
        """
        Get the folder UUID corresponding to the given node and folder indices.

        :param node_idx: The index of the node.
        :type node_idx: int
        :param folder_idx: The index of the folder on the node.
        :type folder_idx: int
        :return: The UUID of the folder. Or None if the node has fewer folders than the given index.
        :rtype: Optional[str]
        """
        node_uuid = self.get_node_uuid_by_idx(node_idx)
        node = self.sim.network.nodes[node_uuid]
        folder_uuids = list(node.file_system.folders.keys())
        return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None

    def get_file_uuid_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]:
        """Get the file UUID corresponding to the given node, folder, and file indices.

        :param node_idx: The index of the node.
        :type node_idx: int
        :param folder_idx: The index of the folder on the node.
        :type folder_idx: int
        :param file_idx: The index of the file in the folder.
        :type file_idx: int
        :return: The UUID of the file. Or None if the node has fewer folders than the given index, or the folder has
            fewer files than the given index.
        :rtype: Optional[str]
        """
        node_uuid = self.get_node_uuid_by_idx(node_idx)
        node = self.sim.network.nodes[node_uuid]
        folder_uuids = list(node.file_system.folders.keys())
        if len(folder_uuids) <= folder_idx:
            return None
        folder = node.file_system.folders[folder_uuids[folder_idx]]
        file_uuids = list(folder.files.keys())
        return file_uuids[file_idx] if len(file_uuids) > file_idx else None

    def get_service_uuid_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]:
        """Get the service UUID corresponding to the given node and service indices.

        :param node_idx: The index of the node.
        :type node_idx: int
        :param service_idx: The index of the service on the node.
        :type service_idx: int
        :return: The UUID of the service. Or None if the node has fewer services than the given index.
        :rtype: Optional[str]
        """
        node_uuid = self.get_node_uuid_by_idx(node_idx)
        node = self.sim.network.nodes[node_uuid]
        service_uuids = list(node.services.keys())
        return service_uuids[service_idx] if len(service_uuids) > service_idx else None

    def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
        """Get the internet protocol corresponding to the given index.

        :param protocol_idx: The index of the protocol to retrieve.
        :type protocol_idx: int
        :return: The protocol.
        :rtype: str
        """
        return self.protocols[protocol_idx]

    def get_ip_address_by_idx(self, ip_idx: int) -> str:
        """
        Get the IP address corresponding to the given index.

        :param ip_idx: The index of the IP address to retrieve.
        :type ip_idx: int
        :return: The IP address.
        :rtype: str
        """
        return self.ip_address_list[ip_idx]

    def get_port_by_idx(self, port_idx: int) -> str:
        """
        Get the port corresponding to the given index.

        :param port_idx: The index of the port to retrieve.
        :type port_idx: int
        :return: The port.
        :rtype: str
        """
        return self.ports[port_idx]

    def get_nic_uuid_by_idx(self, node_idx: int, nic_idx: int) -> Optional[str]:
        """
        Get the NIC UUID corresponding to the given node and NIC indices.

        :param node_idx: The index of the node.
        :type node_idx: int
        :param nic_idx: The index of the NIC on the node.
        :type nic_idx: int
        :return: The NIC UUID, or None if the node has fewer NICs than the given index.
        :rtype: Optional[str]
        """
        node_uuid = self.get_node_uuid_by_idx(node_idx)
        node_obj = self.sim.network.nodes[node_uuid]
        nics = list(node_obj.nics.keys())
        if len(nics) <= nic_idx:
            return None
        return nics[nic_idx]

    @classmethod
    def from_config(cls, session: "PrimaiteSession", cfg: Dict) -> "ActionManager":
        """
        Construct an ActionManager from a config definition.

        The action space config supports the following three sections:
        1. ``action_list``
            ``action_list`` contains a list of action components which need to be included in the action space.
            Each action component has a ``type`` which maps to a subclass of AbstractAction, and additional options
            which will be passed to the action class's __init__ method during initialisation.
        2. ``action_map``
            Since the agent uses a discrete action space which acts as a flattened version of the component-based
            action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful
            action and values of parameters. For example action 0 can correspond to do nothing, action 1 can
            correspond to "NODE_SERVICE_SCAN" with ``node_id=1`` and ``service_id=1``, and so on.
        3. ``options``
            ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method.
            These options are used to calculate the shape of the action space, and to provide additional information
            to the ActionManager which is required to convert the agent's action choice into a CAOS request.

        :param session: The Primaite Session to which the agent belongs.
        :type session: PrimaiteSession
        :param cfg: The action space config.
        :type cfg: Dict
        :return: The constructed ActionManager.
        :rtype: ActionManager
        """
        # NOTE(review): cfg["options"] must not contain protocols/ports/ip_address_list/act_map keys, or the
        # explicit keyword arguments below would raise a duplicate-argument TypeError — confirm against configs.
        return cls(
            session=session,
            actions=cfg["action_list"],
            **cfg["options"],
            protocols=session.options.protocols,
            ports=session.options.ports,
            ip_address_list=None,
            act_map=cfg.get("action_map"),
        )
|
||||
116
src/primaite/game/agent/interface.py
Normal file
116
src/primaite/game/agent/interface.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.primaite.game.agent.actions import ActionManager
|
||||
from src.primaite.game.agent.observations import ObservationSpace
|
||||
from src.primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
    """Base class for scripted and RL agents."""

    def __init__(
        self,
        agent_name: Optional[str],
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationSpace],
        reward_function: Optional[RewardFunction],
    ) -> None:
        """
        Initialize an agent.

        :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes.
        :type agent_name: Optional[str]
        :param action_space: Action space for the agent.
        :type action_space: Optional[ActionManager]
        :param observation_space: Observation space for the agent.
        :type observation_space: Optional[ObservationSpace]
        :param reward_function: Reward function for the agent.
        :type reward_function: Optional[RewardFunction]
        """
        self.agent_name: str = agent_name or "unnamed_agent"
        self.action_space: Optional[ActionManager] = action_space
        self.observation_space: Optional[ObservationSpace] = observation_space
        self.reward_function: Optional[RewardFunction] = reward_function

        # execution definition converts CAOS action to PrimAITE simulator request, sometimes having to enrich the info
        # by for example specifying target ip addresses, or converting a node ID into a uuid
        self.execution_definition = None

    def convert_state_to_obs(self, state: Dict) -> ObsType:
        """
        Convert a state from the simulator into an observation for the agent using the observation space.

        :param state: dict state directly from simulation.describe_state.
        :type state: Dict
        :return: dict state according to CAOS.
        :rtype: ObsType
        """
        return self.observation_space.observe(state)

    def calculate_reward_from_state(self, state: Dict) -> float:
        """
        Use the reward function to calculate a reward from the state.

        :param state: State of the environment.
        :type state: Dict
        :return: Reward from the state.
        :rtype: float
        """
        return self.reward_function.calculate(state)

    @abstractmethod
    def get_action(self, obs: ObsType, reward: Optional[float] = None) -> Tuple[str, Dict]:
        """
        Return an action to be taken in the environment.

        Subclasses should implement agent logic here. It should use the observation as input to decide best next action.

        :param obs: Observation of the environment.
        :type obs: ObsType
        :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted?
        :type reward: Optional[float], optional
        :return: Action to be taken in the environment.
        :rtype: Tuple[str, Dict]
        """
        # in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
        # then use a bespoke conversion to take 1-40 int back into CAOS action
        # BUGFIX: was "DO_NOTHING"; the ActionManager registry keys this action as "DONOTHING".
        return ("DONOTHING", {})

    def format_request(self, action: str, options: Dict[str, int]) -> List[str]:
        """Format action into format expected by the simulator, and apply execution definition if applicable.

        This will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
        Therefore the execution definition needs to be a mapping from CAOS into SIMULATOR.

        :param action: CAOS action identifier (annotation fixed: this is passed to ActionManager.form_request
            as ``action_identifier``, which expects a string).
        :type action: str
        :param options: Parameter choices for the action.
        :type options: Dict[str, int]
        :return: Request in PrimAITE simulator format.
        :rtype: List[str]
        """
        request = self.action_space.form_request(action_identifier=action, action_options=options)
        return request
|
||||
|
||||
|
||||
class AbstractScriptedAgent(AbstractAgent):
    """Base class for actors which generate their own behaviour."""

    pass
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
    """Agent that ignores its observation and acts completely at random."""

    def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
        """Randomly sample an action from the action space.

        :param obs: Observation of the environment; ignored by this agent.
        :type obs: ObsType
        :param reward: Reward from the previous action; ignored by this agent, defaults to None.
        :type reward: float, optional
        :return: Randomly chosen action in CAOS format.
        :rtype: Tuple[str, Dict]
        """
        sampled_choice = self.action_space.space.sample()
        return self.action_space.get_action(sampled_choice)
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
    """Base class for actors controlled via external messages, such as RL policies."""

    pass
|
||||
984
src/primaite/game/agent/observations.py
Normal file
984
src/primaite/game/agent/observations.py
Normal file
@@ -0,0 +1,984 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from src.primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
    """Abstract class for an observation space component."""

    @abstractmethod
    def observe(self, state: Dict) -> Any:
        """
        Return an observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Any
        """
        pass

    @property
    @abstractmethod
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space."""
        pass

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession"):
        """Create this observation space component from a serialised format.

        The `session` parameter is for the PrimaiteSession object that spawns this component. During deserialisation,
        a subclass of this class may need to translate from a 'reference' to a UUID.
        """
        pass
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
    """Observation of a file on a node in the network."""

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """
        Initialise file observation.

        :param where: Stores where in the simulation state dictionary the relevant information can be found.
            Optional. When None, the file is treated as nonexistent and the observation is populated with zeroes.

            A typical location for a file looks like this:
            ['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        self.default_observation: spaces.Space = {"health_status": 0}
        "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        file_state = access_from_nested_dict(state, self.where)
        if file_state is not NOT_PRESENT_IN_STATE:
            return {"health_status": file_state["health_status"]}
        # File missing from the state (e.g. deleted): fall back to the all-zero observation.
        return self.default_observation

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict({"health_status": spaces.Discrete(6)})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: List[str] = None) -> "FileObservation":
        """Create file observation from a config.

        :param config: Dictionary containing the configuration for this file observation. Must contain ``file_name``.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary this file's parent folder is located.
        :type parent_where: List[str], optional
        :return: Constructed file observation
        :rtype: FileObservation
        """
        file_location = parent_where + ["files", config["file_name"]]
        return cls(where=file_location)
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
    """Observation of a service in the network."""

    default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
    "Default observation is what should be returned when the service doesn't exist."

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """Initialise service observation.

        :param where: Stores where in the simulation state dictionary the relevant information can be found.
            Optional. When None, the service is treated as nonexistent and the observation is populated with zeroes.

            A typical location for a service looks like this:
            `['network','nodes',<node_uuid>,'services', <service_uuid>]`
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        service_state = access_from_nested_dict(state, self.where)
        if service_state is not NOT_PRESENT_IN_STATE:
            return {
                "operating_status": service_state["operating_state"],
                "health_status": service_state["health_state"],
            }
        # Service missing from the state: fall back to the all-zero observation.
        return self.default_observation

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None
    ) -> "ServiceObservation":
        """Create service observation from a config.

        :param config: Dictionary containing the configuration for this service observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
        :type parent_where: Optional[List[str]], optional
        :return: Constructed service observation
        :rtype: ServiceObservation
        """
        service_uuid = session.ref_map_services[config["service_ref"]].uuid
        return cls(where=parent_where + ["services", service_uuid])
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
    """Observation of a link in the network."""

    default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
    "Default observation is what should be returned when the link doesn't exist."

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """Initialise link observation.

        :param where: Store information about where in the simulation state dictionary to find the relevant information.
            Optional. If None, this corresponds that the link does not exist and the observation will be populated with
            zeroes.

            A typical location for a link looks like this:
            `['network','links',<link_uuid>]`
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        link_state = access_from_nested_dict(state, self.where)
        if link_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        bandwidth = link_state["bandwidth"]  # assumes bandwidth > 0 -- TODO confirm
        load = link_state["current_load"]
        # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is up to exactly 100%.
        if load == 0:
            # BUGFIX: previously a zero-load link was reported as category 1; the scheme above reserves 0 for UNUSED.
            utilisation_category = 0
        else:
            utilisation_fraction = load / bandwidth
            # BUGFIX: at exactly 100% utilisation, int(1.0 * 10) + 1 == 11, which lies outside spaces.Discrete(11)
            # (valid values 0-10). Clamp to the top category.
            utilisation_category = min(int(utilisation_fraction * 10) + 1, 10)

        # TODO: once the links support separate load per protocol, this needs amendment to reflect that.
        return {"PROTOCOLS": {"ALL": utilisation_category}}

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "LinkObservation":
        """Create link observation from a config.

        :param config: Dictionary containing the configuration for this link observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Constructed link observation
        :rtype: LinkObservation
        """
        return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
    """Folder observation, including files inside of the folder."""

    def __init__(
        self,
        where: Optional[Tuple[str]] = None,
        files: Optional[List[FileObservation]] = None,
        num_files_per_folder: int = 2,
    ) -> None:
        """Initialise folder observation, including files inside of the folder.

        :param where: Where in the simulation state dictionary to find the relevant information for this folder.
            A typical location for a file looks like this:
            ['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
        :type where: Optional[Tuple[str]]
        :param files: File observations belonging to this folder, defaults to None (no files).
        :type files: Optional[List[FileObservation]], optional
        :param num_files_per_folder: As the size of the space must remain static, this defines how many file
            slots are in this folder observation. The file list is padded or truncated to this length,
            defaults to 2
        :type num_files_per_folder: int, optional
        """
        super().__init__()

        self.where: Optional[Tuple[str]] = where

        # Avoid a mutable default argument: a shared default list would be mutated by the padding loop
        # below and leak FileObservation instances between FolderObservation objects.
        self.files: List[FileObservation] = files if files is not None else []
        # Pad with blank file observations (no `where`) so the observation shape stays static.
        while len(self.files) < num_files_per_folder:
            self.files.append(FileObservation())
        while len(self.files) > num_files_per_folder:
            truncated_file = self.files.pop()
            msg = f"Too many files in folder observation. Truncating file {truncated_file}"
            _LOGGER.warning(msg)

        self.default_observation = {
            "health_status": 0,
            "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
        }

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        folder_state = access_from_nested_dict(state, self.where)
        if folder_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        health_status = folder_state["health_status"]

        obs = {}
        obs["health_status"] = health_status
        obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}

        return obs

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict(
            {
                "health_status": spaces.Discrete(6),
                "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
            }
        )

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder: int = 2
    ) -> "FolderObservation":
        """Create folder observation from a config. Also creates child file observations.

        :param config: Dictionary containing the configuration for this folder observation. Includes the name of the
            folder and the files inside of it.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary to find the information about this folder's
            parent node. A typical location for a node ``where`` can be:
            ['network','nodes',<node_uuid>,'file_system']
            NOTE(review): a None value would make the concatenation below fail — callers appear to always
            pass a list; confirm.
        :type parent_where: Optional[List[str]]
        :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
            observation size), defaults to 2
        :type num_files_per_folder: int, optional
        :return: Constructed folder observation
        :rtype: FolderObservation
        """
        where = parent_where + ["folders", config["folder_name"]]

        file_configs = config["files"]
        files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]

        return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
    """Observation of a Network Interface Card (NIC) in the network."""

    # Class-level default returned whenever the NIC cannot be observed. It is a plain dict (the
    # previous `spaces.Space` annotation was incorrect) and is never mutated, so sharing it between
    # instances is safe.
    default_observation: Dict = {"nic_status": 0}

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """Initialise NIC observation.

        :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
            example may look like this:
            ['network','nodes',<node_uuid>,'NICs',<nic_uuid>]
            If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
        :type where: Optional[Tuple[str]], optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        nic_state = access_from_nested_dict(state, self.where)
        if nic_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        else:
            # 0 = not present, 1 = enabled, 2 = disabled.
            return {"nic_status": 1 if nic_state["enabled"] else 2}

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Dict({"nic_status": spaces.Discrete(3)})

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]
    ) -> "NicObservation":
        """Create NIC observation from a config.

        :param config: Dictionary containing the configuration for this NIC observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
            node. A typical location for a node ``where`` can be: ['network','nodes',<node_uuid>]
        :type parent_where: Optional[List[str]]
        :return: Constructed NIC observation
        :rtype: NicObservation
        """
        return cls(where=parent_where + ["NICs", config["nic_uuid"]])
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
    """Observation of a node in the network. Includes services, folders and NICs."""

    def __init__(
        self,
        where: Optional[Tuple[str]] = None,
        services: Optional[List[ServiceObservation]] = None,
        folders: Optional[List[FolderObservation]] = None,
        nics: Optional[List[NicObservation]] = None,
        logon_status: bool = False,
        num_services_per_node: int = 2,
        num_folders_per_node: int = 2,
        num_files_per_folder: int = 2,
        num_nics_per_node: int = 2,
    ) -> None:
        """
        Configurable observation for a node in the simulation.

        :param where: Where in the simulation state dictionary to find relevant information for this observation.
            A typical location for a node looks like this:
            ['network','nodes',<node_uuid>]. If None, a default null observation will be output, defaults to None
        :type where: Optional[Tuple[str]], optional
        :param services: Service observations for this node, padded/truncated to num_services_per_node,
            defaults to None (no services).
        :type services: Optional[List[ServiceObservation]], optional
        :param folders: Folder observations for this node, padded/truncated to num_folders_per_node,
            defaults to None (no folders).
        :type folders: Optional[List[FolderObservation]], optional
        :param nics: NIC observations for this node, padded/truncated to num_nics_per_node,
            defaults to None (no NICs).
        :type nics: Optional[List[NicObservation]], optional
        :param logon_status: Whether to include a logon status entry in the observation, defaults to False
        :type logon_status: bool, optional
        :param num_services_per_node: Max number of services in this node's obs space, defaults to 2
        :type num_services_per_node: int, optional
        :param num_folders_per_node: Max number of folders in this node's obs space, defaults to 2
        :type num_folders_per_node: int, optional
        :param num_files_per_folder: Max number of files in each folder observation, defaults to 2
        :type num_files_per_folder: int, optional
        :param num_nics_per_node: Max number of NICs in this node's obs space, defaults to 2
        :type num_nics_per_node: int, optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

        # Use None sentinels instead of mutable default arguments: shared default lists would be
        # mutated by the padding loops below and leak components between NodeObservation instances.
        self.services: List[ServiceObservation] = services if services is not None else []
        while len(self.services) < num_services_per_node:
            # add empty service observation without `where` parameter so it always returns default (blank) observation
            self.services.append(ServiceObservation())
        while len(self.services) > num_services_per_node:
            truncated_service = self.services.pop()
            msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
            _LOGGER.warning(msg)

        self.folders: List[FolderObservation] = folders if folders is not None else []
        # add empty folder observation without `where` parameter that will always return default (blank) observations
        while len(self.folders) < num_folders_per_node:
            self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
        while len(self.folders) > num_folders_per_node:
            truncated_folder = self.folders.pop()
            msg = f"Too many folders in Node observation for node. Truncating folder {truncated_folder.where[-1]}"
            _LOGGER.warning(msg)

        self.nics: List[NicObservation] = nics if nics is not None else []
        while len(self.nics) < num_nics_per_node:
            self.nics.append(NicObservation())
        while len(self.nics) > num_nics_per_node:
            truncated_nic = self.nics.pop()
            msg = f"Too many NICs in Node observation for node. Truncating NIC {truncated_nic.where[-1]}"
            _LOGGER.warning(msg)

        self.logon_status: bool = logon_status

        self.default_observation: Dict = {
            "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
            "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
            "NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
            "operating_status": 0,
        }
        if self.logon_status:
            self.default_observation["logon_status"] = 0

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        node_state = access_from_nested_dict(state, self.where)
        if node_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        obs = {}
        obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
        obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
        obs["operating_status"] = node_state["operating_state"]
        obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}

        if self.logon_status:
            # Logon status is not yet tracked by the simulation; always reported as 0.
            obs["logon_status"] = 0

        return obs

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        space_shape = {
            "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
            "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
            "operating_status": spaces.Discrete(5),
            "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
        }
        if self.logon_status:
            space_shape["logon_status"] = spaces.Discrete(3)

        return spaces.Dict(space_shape)

    @classmethod
    def from_config(
        cls,
        config: Dict,
        session: "PrimaiteSession",
        parent_where: Optional[List[str]] = None,
        num_services_per_node: int = 2,
        num_folders_per_node: int = 2,
        num_files_per_folder: int = 2,
        num_nics_per_node: int = 2,
    ) -> "NodeObservation":
        """Create node observation from a config. Also creates child service, folder and NIC observations.

        :param config: Dictionary containing the configuration for this node observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary to find the information about this node's parent
            network. A typical location for it would be: ['network',]
        :type parent_where: Optional[List[str]]
        :param num_services_per_node: How many spaces for services are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_services_per_node: int, optional
        :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_folders_per_node: int, optional
        :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
            observation size), defaults to 2
        :type num_files_per_folder: int, optional
        :param num_nics_per_node: How many spaces for NICs are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_nics_per_node: int, optional
        :return: Constructed node observation
        :rtype: NodeObservation
        """
        node_uuid = session.ref_map_nodes[config["node_ref"]]
        if parent_where is None:
            where = ["network", "nodes", node_uuid]
        else:
            where = parent_where + ["nodes", node_uuid]

        svc_configs = config.get("services", {})
        services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
        folder_configs = config.get("folders", {})
        folders = [
            FolderObservation.from_config(
                config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
            )
            for c in folder_configs
        ]
        nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
        # Bug fix: this used to be `[{"nic_uuid": n for n in nic_uuids}]` — a single dict
        # comprehension that kept only the LAST nic uuid. Build one config dict per NIC instead.
        nic_configs = [{"nic_uuid": n} for n in nic_uuids]
        nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
        logon_status = config.get("logon_status", False)
        return cls(
            where=where,
            services=services,
            folders=folders,
            nics=nics,
            logon_status=logon_status,
            num_services_per_node=num_services_per_node,
            num_folders_per_node=num_folders_per_node,
            num_files_per_folder=num_files_per_folder,
            num_nics_per_node=num_nics_per_node,
        )
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
    """Observation of an Access Control List (ACL) in the network."""

    def __init__(
        self,
        node_ip_to_id: Dict[str, int],
        ports: List[int],
        protocols: List[str],
        where: Optional[Tuple[str]] = None,
        num_rules: int = 10,
    ) -> None:
        """Initialise ACL observation.

        :param node_ip_to_id: Mapping between IP address and ID.
        :type node_ip_to_id: Dict[str, int]
        :param ports: List of ports which are part of the game that define the ordering when converting to an ID
        :type ports: List[int]
        :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
        :type protocols: list[str]
        :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
            example may look like this:
            ['network','nodes',<router_uuid>,'acl','acl']
        :type where: Optional[Tuple[str]], optional
        :param num_rules: Number of rule slots in the observation space, defaults to 10
        :type num_rules: int, optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        self.num_rules: int = num_rules
        # Mapping from node IP address to its observation-space ID.
        self.node_to_id: Dict[str, int] = node_ip_to_id
        # Mapping from port number to ID. The +2 offset reserves 0 (unused) and 1 (any).
        self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
        # Mapping from protocol name to ID, with the same reserved-value offset as ports.
        self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
        self.default_observation: Dict = {i + 1: self._unused_rule_entry(i) for i in range(self.num_rules)}

    @staticmethod
    def _unused_rule_entry(position: int) -> Dict:
        """Return the all-zero observation entry used for an empty/unused ACL rule slot."""
        return {
            "position": position,
            "permission": 0,
            "source_node_id": 0,
            "source_port": 0,
            "dest_node_id": 0,
            "dest_port": 0,
            "protocol": 0,
        }

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        acl_state: Dict = access_from_nested_dict(state, self.where)
        if acl_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        # TODO: what if the ACL has more rules than num of max rules for obs space
        obs = {}
        for i, rule_state in acl_state.items():
            if rule_state is None:
                obs[i + 1] = self._unused_rule_entry(i)
            else:
                obs[i + 1] = {
                    "position": i,
                    "permission": rule_state["action"],
                    "source_node_id": self.node_to_id[rule_state["src_ip_address"]],
                    "source_port": self.port_to_id[rule_state["src_port"]],
                    "dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
                    "dest_port": self.port_to_id[rule_state["dst_port"]],
                    "protocol": self.protocol_to_id[rule_state["protocol"]],
                }
        return obs

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict(
            {
                i
                + 1: spaces.Dict(
                    {
                        "position": spaces.Discrete(self.num_rules),
                        "permission": spaces.Discrete(3),
                        # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
                        "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
                        "source_port": spaces.Discrete(len(self.port_to_id) + 2),
                        "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
                        "dest_port": spaces.Discrete(len(self.port_to_id) + 2),
                        "protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
                    }
                )
                for i in range(self.num_rules)
            }
        )

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
        """Generate ACL observation from a config.

        :param config: Dictionary containing the configuration for this ACL observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Observation object
        :rtype: AclObservation
        """
        max_acl_rules = config["options"]["max_acl_rules"]
        node_ip_to_idx = {}
        for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
            node_ref = ip_map_config["node_ref"]
            nic_num = ip_map_config["nic_num"]
            node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
            nic_obj = node_obj.ethernet_port[nic_num]
            # +2 offset keeps IDs 0 and 1 reserved (0 = unused, 1 = any).
            node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2

        router_uuid = session.ref_map_nodes[config["router_node_ref"]]
        return cls(
            node_ip_to_id=node_ip_to_idx,
            ports=session.options.ports,
            protocols=session.options.protocols,
            where=["network", "nodes", router_uuid, "acl", "acl"],
            num_rules=max_acl_rules,
        )
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
    """Null observation, returns a single 0 value for the observation space."""

    def __init__(self, where: Optional[List[str]] = None):
        """Initialise null observation.

        :param where: Ignored; accepted only so the signature matches the other observation classes.
        :type where: Optional[List[str]], optional
        """
        # Call the base initialiser for consistency with the other observation classes.
        super().__init__()
        self.default_observation: Dict = {}

    def observe(self, state: Dict) -> int:
        """Generate observation based on the current state of the simulation; always 0 for a null observation."""
        return 0

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Discrete(1)

    @classmethod
    def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
        """
        Create null observation from a config.

        The parameters are ignored, they are here to match the signature of the other observation classes.
        """
        return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
    """Placeholder for an ICS observation; not yet implemented, so it behaves exactly like NullObservation."""
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
    """Container for all observations used by the blue agent in UC2.

    TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
    for the purpose of compiling several observation components.
    """

    def __init__(
        self,
        nodes: List[NodeObservation],
        links: List[LinkObservation],
        acl: AclObservation,
        ics: ICSObservation,
        where: Optional[List[str]] = None,
    ) -> None:
        """Initialise UC2 blue observation.

        :param nodes: List of node observations
        :type nodes: List[NodeObservation]
        :param links: List of link observations
        :type links: List[LinkObservation]
        :param acl: The Access Control List observation
        :type acl: AclObservation
        :param ics: The ICS observation
        :type ics: ICSObservation
        :param where: Where in the simulation state dict to find information. Not used by this observation since it
            only aggregates other observations and contributes no new information, defaults to None
        :type where: Optional[List[str]], optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

        self.nodes: List[NodeObservation] = nodes
        self.links: List[LinkObservation] = links
        self.acl: AclObservation = acl
        self.ics: ICSObservation = ics

        node_defaults = {idx + 1: node.default_observation for idx, node in enumerate(self.nodes)}
        link_defaults = {idx + 1: link.default_observation for idx, link in enumerate(self.links)}
        self.default_observation: Dict = {
            "NODES": node_defaults,
            "LINKS": link_defaults,
            "ACL": self.acl.default_observation,
            "ICS": self.ics.default_observation,
        }

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation

        return {
            "NODES": {idx + 1: node.observe(state) for idx, node in enumerate(self.nodes)},
            "LINKS": {idx + 1: link.observe(state) for idx, link in enumerate(self.links)},
            "ACL": self.acl.observe(state),
            "ICS": self.ics.observe(state),
        }

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Space
        :rtype: spaces.Space
        """
        node_spaces = spaces.Dict({idx + 1: node.space for idx, node in enumerate(self.nodes)})
        link_spaces = spaces.Dict({idx + 1: link.space for idx, link in enumerate(self.links)})
        return spaces.Dict({"NODES": node_spaces, "LINKS": link_spaces, "ACL": self.acl.space, "ICS": self.ics.space})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2BlueObservation":
        """Create UC2 blue observation from a config.

        :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
            links, ACL and ICS observations.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Constructed UC2 blue observation
        :rtype: UC2BlueObservation
        """
        node_configs = config["nodes"]
        num_services_per_node = config["num_services_per_node"]
        num_folders_per_node = config["num_folders_per_node"]
        num_files_per_folder = config["num_files_per_folder"]
        num_nics_per_node = config["num_nics_per_node"]
        node_observations = [
            NodeObservation.from_config(
                config=node_cfg,
                session=session,
                num_services_per_node=num_services_per_node,
                num_folders_per_node=num_folders_per_node,
                num_files_per_folder=num_files_per_folder,
                num_nics_per_node=num_nics_per_node,
            )
            for node_cfg in node_configs
        ]

        link_configs = config["links"]
        link_observations = [LinkObservation.from_config(config=link_cfg, session=session) for link_cfg in link_configs]

        acl_observation = AclObservation.from_config(config=config["acl"], session=session)

        ics_observation = ICSObservation.from_config(config=config["ics"], session=session)
        return cls(
            nodes=node_observations, links=link_observations, acl=acl_observation, ics=ics_observation, where=["network"]
        )
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
    """Container for all observations used by the red agent in UC2."""

    def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
        """Initialise UC2 red observation.

        :param nodes: List of node observations
        :type nodes: List[NodeObservation]
        :param where: Where in the simulation state dict to find information, defaults to None
        :type where: Optional[List[str]], optional
        """
        super().__init__()
        self.where: Optional[List[str]] = where
        self.nodes: List[NodeObservation] = nodes

        node_defaults = {idx + 1: node.default_observation for idx, node in enumerate(self.nodes)}
        self.default_observation: Dict = {"NODES": node_defaults}

    def observe(self, state: Dict) -> Dict:
        """Generate observation based on the current state of the simulation."""
        if self.where is None:
            return self.default_observation

        return {"NODES": {idx + 1: node.observe(state) for idx, node in enumerate(self.nodes)}}

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        node_spaces = {idx + 1: node.space for idx, node in enumerate(self.nodes)}
        return spaces.Dict({"NODES": spaces.Dict(node_spaces)})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "UC2RedObservation":
        """
        Create UC2 red observation from a config.

        :param config: Dictionary containing the configuration for this UC2 red observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        """
        node_observations = [NodeObservation.from_config(config=cfg, session=session) for cfg in config["nodes"]]
        return cls(nodes=node_observations, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
    """Observation for the green agent. Green actions don't depend on observations, so this is a null observation."""
|
||||
|
||||
|
||||
class ObservationSpace:
    """
    Manage the observations of an Agent.

    The observation space has the purpose of:
    1. Reading the outputted state from the PrimAITE Simulation.
    2. Selecting parts of the simulation state that are requested by the simulation config
    3. Formatting this information so an agent can use it to make decisions.
    """

    # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
    # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
    # refactor.

    def __init__(self, observation: AbstractObservation) -> None:
        """Initialise observation space.

        :param observation: Observation object
        :type observation: AbstractObservation
        """
        self.obs: AbstractObservation = observation

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation produced by the wrapped observation object.
        :rtype: Dict
        """
        return self.obs.observe(state)

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        # Annotation fixed: this property returns the wrapped observation's space, not None.
        return self.obs.space

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
        """Create observation space from a config.

        :param config: Dictionary containing the configuration for this observation space.
            It should contain the key 'type' which selects which observation class to use (from a choice of:
            UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
            The other key is 'options' which are passed to the constructor of the selected observation class.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :raises ValueError: If the 'type' key is not one of the supported observation types.
        :return: The constructed observation space.
        :rtype: ObservationSpace
        """
        obs_type = config["type"]
        if obs_type == "UC2BlueObservation":
            return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
        elif obs_type == "UC2RedObservation":
            return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
        elif obs_type == "UC2GreenObservation":
            return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
        else:
            # Include the offending value so misconfigurations are easy to diagnose.
            raise ValueError(f"Observation space type invalid: {obs_type}")
|
||||
284
src/primaite/game/agent/rewards.py
Normal file
284
src/primaite/game/agent/rewards.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Manages the reward function for the agent.
|
||||
|
||||
Each agent is equipped with a RewardFunction, which is made up of a list of reward components. The components are
|
||||
designed to calculate a reward value based on the current state of the simulation. The overall reward function is a
|
||||
weighted sum of the components.
|
||||
|
||||
The reward function is typically specified using a config yaml file or a config dictionary. The following example shows
|
||||
the structure:
|
||||
```yaml
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_database_client
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from src.primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.primaite.game.session import PrimaiteSession
|
||||
|
||||
|
||||
class AbstractReward:
    """Base class for reward function components.

    Subclasses implement ``calculate`` to score the current simulation state and
    ``from_config`` to build themselves from a config dictionary.
    """

    # NOTE(review): this class does not inherit from abc.ABC, so the @abstractmethod
    # decorators below are not enforced at instantiation time — confirm intended.

    @abstractmethod
    def calculate(self, state: Dict) -> float:
        """Calculate the reward for the current state.

        :param state: The current state of the simulation.
        :type state: Dict
        :return: Reward value for this component (this default implementation returns 0.0).
        :rtype: float
        """
        return 0.0

    @classmethod
    @abstractmethod
    def from_config(cls, config: dict, session: "PrimaiteSession") -> "AbstractReward":
        """Create a reward function component from a config dictionary.

        :param config: dict of options for the reward component's constructor
        :type config: dict
        :param session: Reference to the PrimAITE Session object
        :type session: PrimaiteSession
        :return: The reward component.
        :rtype: AbstractReward
        """
        return cls()
|
||||
|
||||
|
||||
class DummyReward(AbstractReward):
    """Dummy reward function component which always returns 0."""

    def calculate(self, state: Dict) -> float:
        """Return a constant zero reward regardless of the simulation state.

        :param state: The current state of the simulation (unused).
        :type state: Dict
        :return: Always ``0.0``.
        :rtype: float
        """
        return 0.0

    @classmethod
    def from_config(cls, config: dict, session: "PrimaiteSession") -> "DummyReward":
        """Build a DummyReward from config.

        Both arguments are accepted for interface parity with other reward
        components but are not used.

        :param config: dict of options for the reward component's constructor. Should be empty.
        :type config: dict
        :param session: Reference to the PrimAITE Session object
        :type session: PrimaiteSession
        """
        return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward):
    """Reward function component which rewards the agent for maintaining the integrity of a database file."""

    def __init__(self, node_uuid: str, folder_name: str, file_name: str) -> None:
        """Initialise the reward component.

        :param node_uuid: UUID of the node which contains the database file.
        :type node_uuid: str
        :param folder_name: folder which contains the database file.
        :type folder_name: str
        :param file_name: name of the database file.
        :type file_name: str
        """
        # Key path to the database file's entry within the nested simulation state dict.
        self.location_in_state = [
            "network",
            "nodes",
            node_uuid,
            "file_system",
            "folders",
            folder_name,
            "files",
            file_name,
        ]

    def calculate(self, state: Dict) -> float:
        """Calculate the reward for the current state.

        :param state: The current state of the simulation.
        :type state: Dict
        :return: 1 if the file health is "good", -1 if "corrupted", 0 otherwise or if the file is absent.
        :rtype: float
        """
        database_file_state = access_from_nested_dict(state, self.location_in_state)
        # Fix: guard against the file being absent from the state. Previously the sentinel
        # object was indexed directly, raising TypeError (WebServer404Penalty already guards).
        if database_file_state is NOT_PRESENT_IN_STATE:
            _LOGGER.debug("Database file not found in simulation state; returning 0 reward.")
            return 0
        health_status = database_file_state["health_status"]
        if health_status == "corrupted":
            return -1
        elif health_status == "good":
            return 1
        else:
            return 0

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "DatabaseFileIntegrity":
        """Create a reward function component from a config dictionary.

        :param config: dict of options for the reward component's constructor
        :type config: Dict
        :param session: Reference to the PrimAITE Session object
        :type session: PrimaiteSession
        :return: The reward component, or a DummyReward if the config is invalid.
        :rtype: DatabaseFileIntegrity
        """
        node_ref = config.get("node_ref")
        folder_name = config.get("folder_name")
        file_name = config.get("file_name")
        # Validate all required options in one place instead of three copy-pasted checks.
        required = (("node_ref", node_ref), ("folder_name", folder_name), ("file_name", file_name))
        for param_name, value in required:
            if not value:
                _LOGGER.error(
                    f"{cls.__name__} could not be initialised from config because {param_name} parameter "
                    "was not specified"
                )
                return DummyReward()  # TODO: better error handling
        # Fix: use .get() — the original indexed ref_map_nodes directly, raising KeyError
        # before the `if not node_uuid` fallback below could ever run.
        node_uuid = session.ref_map_nodes.get(node_ref)
        if not node_uuid:
            _LOGGER.error(
                (
                    f"{cls.__name__} could not be initialised from config because the referenced node could not be "
                    f"found in the simulation"
                )
            )
            return DummyReward()  # TODO: better error handling

        return cls(node_uuid=node_uuid, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward):
    """Reward function component which penalises the agent when the web server returns a 404 error."""

    def __init__(self, node_uuid: str, service_uuid: str) -> None:
        """Initialise the reward component.

        :param node_uuid: UUID of the node which contains the web server service.
        :type node_uuid: str
        :param service_uuid: UUID of the web server service.
        :type service_uuid: str
        """
        # Key path to the web service's entry within the nested simulation state dict.
        self.location_in_state = ["network", "nodes", node_uuid, "services", service_uuid]

    def calculate(self, state: Dict) -> float:
        """Calculate the reward for the current state.

        :param state: The current state of the simulation.
        :type state: Dict
        :return: 1.0 for a 200 response, -1.0 for a 404, 0.0 otherwise or if the service is absent.
        :rtype: float
        """
        web_service_state = access_from_nested_dict(state, self.location_in_state)
        if web_service_state is NOT_PRESENT_IN_STATE:
            # Fix: report through the module logger instead of a bare print().
            _LOGGER.debug("Web service state not found in simulation state; returning 0 reward.")
            return 0.0
        most_recent_return_code = web_service_state["last_response_status_code"]
        # TODO: reward needs to use the current web state. Observation should return web state at the time of last scan.
        if most_recent_return_code == 200:
            return 1.0
        elif most_recent_return_code == 404:
            return -1.0
        else:
            return 0.0

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "WebServer404Penalty":
        """Create a reward function component from a config dictionary.

        :param config: dict of options for the reward component's constructor
        :type config: Dict
        :param session: Reference to the PrimAITE Session object
        :type session: PrimaiteSession
        :return: The reward component, or a DummyReward if the config is invalid.
        :rtype: WebServer404Penalty
        """
        node_ref = config.get("node_ref")
        service_ref = config.get("service_ref")
        if not (node_ref and service_ref):
            msg = (
                f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
                "found in reward config."
            )
            # Fix: Logger.warn is a deprecated alias of Logger.warning.
            _LOGGER.warning(msg)
            return DummyReward()  # TODO: should we error out with incorrect inputs? Probably!
        # Fix: use .get() — the original indexed the ref maps directly, raising KeyError
        # before the `if not (node_uuid and service_uuid)` fallback below could ever run.
        node_uuid = session.ref_map_nodes.get(node_ref)
        service = session.ref_map_services.get(service_ref)
        service_uuid = service.uuid if service is not None else None
        if not (node_uuid and service_uuid):
            msg = (
                f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
                " found in the simulator."
            )
            _LOGGER.warning(msg)
            return DummyReward()  # TODO: consider erroring here as well

        return cls(node_uuid=node_uuid, service_uuid=service_uuid)
|
||||
|
||||
|
||||
class RewardFunction:
    """Manages the reward function for the agent."""

    # Maps config 'type' identifiers to the reward component classes they select.
    __rew_class_identifiers: Dict[str, type[AbstractReward]] = {
        "DUMMY": DummyReward,
        "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
        "WEB_SERVER_404_PENALTY": WebServer404Penalty,
    }

    def __init__(self):
        """Initialise the reward function object."""
        # Tracks reward components together with the weight assigned to each.
        self.reward_components: List[Tuple[AbstractReward, float]] = []

    def register_component(self, component: AbstractReward, weight: float = 1.0) -> None:
        """Add a reward component to the reward function.

        :param component: Instance of a reward component.
        :type component: AbstractReward
        :param weight: Relative weight of the reward component, defaults to 1.0
        :type weight: float, optional
        """
        self.reward_components.append((component, weight))

    def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
        """Deprecated misspelled alias of :meth:`register_component`, kept for backward compatibility."""
        self.register_component(component=component, weight=weight)

    def calculate(self, state: Dict) -> float:
        """Calculate the overall reward for the current state.

        :param state: The current state of the simulation.
        :type state: Dict
        :return: Weighted sum of all registered component rewards.
        :rtype: float
        """
        total = 0.0
        # Unpack the (component, weight) tuples directly instead of indexing.
        for comp, weight in self.reward_components:
            total += weight * comp.calculate(state=state)
        return total

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
        """Create a reward function from a config dictionary.

        :param config: dict of options for the reward manager's constructor
        :type config: Dict
        :param session: Reference to the PrimAITE Session object
        :type session: PrimaiteSession
        :return: The reward manager.
        :rtype: RewardFunction
        """
        new = cls()

        for rew_component_cfg in config["reward_components"]:
            rew_type = rew_component_cfg["type"]
            weight = rew_component_cfg.get("weight", 1.0)
            rew_class = cls.__rew_class_identifiers[rew_type]
            rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}), session=session)
            new.register_component(component=rew_instance, weight=weight)
        return new
|
||||
14
src/primaite/game/agent/scripted_agents.py
Normal file
14
src/primaite/game/agent/scripted_agents.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Agents with predefined behaviours."""
|
||||
from src.primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class GreenWebBrowsingAgent(AbstractScriptedAgent):
    """Scripted agent which attempts to send web requests to a target node."""

    def __init__(self, *args, **kwargs) -> None:
        """Placeholder constructor — this agent's behaviour is not implemented yet.

        :raises NotImplementedError: Always.
        """
        # Fix: the original `raise NotImplementedError` sat directly in the class body,
        # which executes at class-definition time and made the whole module unimportable.
        # Raising on instantiation preserves the "not implemented" contract.
        raise NotImplementedError
||||
|
||||
|
||||
class RedDatabaseCorruptingAgent(AbstractScriptedAgent):
    """Scripted agent which attempts to corrupt the database of the target node."""

    def __init__(self, *args, **kwargs) -> None:
        """Placeholder constructor — this agent's behaviour is not implemented yet.

        :raises NotImplementedError: Always.
        """
        # Fix: the original `raise NotImplementedError` sat directly in the class body,
        # which executes at class-definition time and made the whole module unimportable.
        raise NotImplementedError
||||
30
src/primaite/game/agent/utils.py
Normal file
30
src/primaite/game/agent/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Any, Dict, Hashable, Sequence
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
"""
Sentinel returned when the sim state does not contain a requested value. Cannot use None because sometimes
the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is a sentinel for this purpose.
"""


def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
    """
    Access an item from a deeply nested dictionary with a sequence of keys.

    For example, if the dictionary is {1: 'a', 2: {3: {4: 'b'}}}, then the key [2, 3, 4] would return 'b', and the key
    [2, 3] would return {4: 'b'}. Returns the ``NOT_PRESENT_IN_STATE`` sentinel if the specified key does not exist
    at any level of nesting, or if a non-dict value is reached before the keys are exhausted. (The original docstring
    incorrectly stated that a KeyError was raised.)

    :param dictionary: Deeply nested dictionary
    :type dictionary: Dict
    :param keys: Sequence of dict keys used to traverse the nested dict. Each item corresponds to one level of depth.
    :type keys: Sequence[Hashable]
    :return: The value in the dictionary, or ``NOT_PRESENT_IN_STATE`` if the path does not resolve.
    :rtype: Any
    """
    # Iterative traversal: avoids both the recursion of the original and copying the key list.
    current = dictionary
    for key in keys:
        # Fix: the original performed `key not in current` on whatever value was reached,
        # raising TypeError when an intermediate value was not a container.
        if not isinstance(current, dict) or key not in current:
            return NOT_PRESENT_IN_STATE
        current = current[key]
    return current
|
||||
471
src/primaite/game/session.py
Normal file
471
src/primaite/game/session.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from src.primaite.game.agent.actions import ActionManager
|
||||
from src.primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from src.primaite.game.agent.observations import ObservationSpace
|
||||
from src.primaite.game.agent.rewards import RewardFunction
|
||||
from src.primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from src.primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from src.primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from src.primaite.simulator.network.hardware.nodes.server import Server
|
||||
from src.primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.sim_container import Simulation
|
||||
from src.primaite.simulator.system.applications.application import Application
|
||||
from src.primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from src.primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from src.primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from src.primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from src.primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from src.primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
from src.primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
    """Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""

    def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
        """
        Create a new GATE client for PrimAITE.

        :param parent_session: The parent session object.
        :type parent_session: PrimaiteSession
        :param service_port: The port on which the GATE service is running.
        :type service_port: int, optional
        """
        super().__init__(service_port=service_port)
        # Back-reference used by every property/method below to reach the session's
        # training options and the RL agent.
        self.parent_session: "PrimaiteSession" = parent_session

    # The following read-only properties expose the parent session's training options
    # to GATE through the GATEClient interface.

    @property
    def rl_framework(self) -> str:
        """The reinforcement learning framework to use."""
        return self.parent_session.training_options.rl_framework

    @property
    def rl_algorithm(self) -> str:
        """The reinforcement learning algorithm to use."""
        return self.parent_session.training_options.rl_algorithm

    @property
    def seed(self) -> int | None:
        """The seed to use for the environment's random number generator."""
        return self.parent_session.training_options.seed

    @property
    def n_learn_episodes(self) -> int:
        """The number of episodes in each learning run."""
        return self.parent_session.training_options.n_learn_episodes

    @property
    def n_learn_steps(self) -> int:
        """The number of steps in each learning episode."""
        return self.parent_session.training_options.n_learn_steps

    @property
    def n_eval_episodes(self) -> int:
        """The number of episodes in each evaluation run."""
        return self.parent_session.training_options.n_eval_episodes

    @property
    def n_eval_steps(self) -> int:
        """The number of steps in each evaluation episode."""
        return self.parent_session.training_options.n_eval_steps

    @property
    def action_space(self) -> spaces.Space:
        """The gym action space of the agent."""
        return self.parent_session.rl_agent.action_space.space

    @property
    def observation_space(self) -> spaces.Space:
        """The gymnasium observation space of the agent."""
        # Flattened so GATE sees a simple 1-D space rather than the nested Dict space.
        return flatten_space(self.parent_session.rl_agent.observation_space.space)

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
        """Take a step in the environment.

        This method is called by GATE to advance the simulation by one timestep.

        :param action: The agent's action.
        :type action: ActType
        :return: The observation, reward, terminal flag, truncated flag, and info dictionary.
        :rtype: Tuple[ObsType, float, bool, bool, Dict]
        """
        # Hand the chosen action to the RL agent, then let the session run the full
        # simulation/agent loop for one timestep.
        self.parent_session.rl_agent.most_recent_action = action
        self.parent_session.step()
        state = self.parent_session.simulation.describe_state()
        obs = self.parent_session.rl_agent.observation_space.observe(state)
        # Flatten to match the flattened observation_space reported above.
        obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
        rew = self.parent_session.rl_agent.reward_function.calculate(state)
        # NOTE(review): termination/truncation are never signalled from here —
        # presumably GATE enforces the episode step limits; confirm.
        term = False
        trunc = False
        info = {}
        return obs, rew, term, trunc, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
        """Reset the environment.

        This method is called when the environment is initialized and at the end of each episode.

        :param seed: The seed to use for the environment's random number generator.
        :type seed: int, optional
        :param options: Additional options for the reset. None are used by PrimAITE but this is included for
            compatibility with GATE.
        :type options: dict[str, Any], optional
        :return: The initial observation and an empty info dictionary.
        :rtype: Tuple[ObsType, Dict]
        """
        # NOTE(review): PrimaiteSession.reset currently returns NotImplemented rather
        # than resetting anything — episodes may therefore not start from a clean state.
        self.parent_session.reset()
        state = self.parent_session.simulation.describe_state()
        obs = self.parent_session.rl_agent.observation_space.observe(state)
        obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
        return obs, {}

    def close(self):
        """Close the session, this will stop the gate client and close the simulation."""
        self.parent_session.close()
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
    """
    Global options which are applicable to all of the agents in the game.

    Currently this is used to restrict which ports and protocols exist in the world of the simulation.
    """

    # Names of the ports that exist in the world of the simulation.
    ports: List[str]
    # Names of the protocols that exist in the world of the simulation.
    protocols: List[str]
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
    """Options for training the RL agent.

    These values are read by ``PrimaiteGATEClient`` properties and exposed to GATE.
    """

    # The reinforcement learning framework to use.
    rl_framework: str
    # The reinforcement learning algorithm to use.
    rl_algorithm: str
    # Seed for the environment's random number generator; None means unseeded.
    seed: Optional[int]
    # Number of episodes in each learning run.
    n_learn_episodes: int
    # Number of steps in each learning episode.
    n_learn_steps: int
    # Number of episodes in each evaluation run.
    n_eval_episodes: int
    # Number of steps in each evaluation episode.
    n_eval_steps: int
|
||||
|
||||
|
||||
class PrimaiteSession:
    """The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""

    def __init__(self):
        """Initialise an empty session; it is normally populated via ``from_config``."""
        self.simulation: Simulation = Simulation()
        """Simulation object with which the agents will interact."""
        self.agents: List[AbstractAgent] = []
        """List of agents."""
        self.rl_agent: AbstractAgent
        """The agent from the list which communicates with GATE to perform reinforcement learning."""
        self.step_counter: int = 0
        """Current timestep within the episode."""
        self.episode_counter: int = 0
        """Current episode number."""
        self.options: PrimaiteSessionOptions
        """Special options that apply for the entire game."""
        self.training_options: TrainingOptions
        """Options specific to agent training."""

        self.ref_map_nodes: Dict[str, Node] = {}
        """Mapping from unique node reference name to node object. Used when parsing config files."""
        self.ref_map_services: Dict[str, Service] = {}
        """Mapping from human-readable service reference to service object. Used for parsing config files."""
        self.ref_map_applications: Dict[str, Application] = {}
        """Mapping from human-readable application reference to application object. Used for parsing config files."""
        self.ref_map_links: Dict[str, Link] = {}
        """Mapping from human-readable link reference to link object. Used when parsing config files."""
        self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
        """Reference to a GATE Client object, which will send data to GATE service for training RL agent."""

    def start_session(self) -> None:
        """Commence the training session, this gives the GATE client control over the simulation/agent loop."""
        self.gate_client.start()

    def step(self):
        """
        Perform one step of the simulation/agent loop.

        This is the main loop of the game. It corresponds to one timestep in the simulation, and one action from each
        agent. The steps are as follows:
        1. The simulation state is updated.
        2. The simulation state is sent to each agent.
        3. Each agent converts the state to an observation and calculates a reward.
        4. Each agent chooses an action based on the observation.
        5. Each agent converts the action to a request.
        6. The simulation applies the requests.
        """
        _LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
        # currently designed with assumption that all agents act once per step in order

        for agent in self.agents:
            # 3. primaite session asks simulation to provide initial state
            # 4. primate session gives state to all agents
            # 5. primaite session asks agents to produce an action based on most recent state
            _LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
            # NOTE(review): the state is re-queried inside the loop, so each agent sees the
            # effects of requests applied by agents earlier in the list within the same step.
            sim_state = self.simulation.describe_state()

            # 6. each agent takes most recent state and converts it to CAOS observation
            agent_obs = agent.convert_state_to_obs(sim_state)

            # 7. meanwhile each agent also takes state and calculates reward
            agent_reward = agent.calculate_reward_from_state(sim_state)

            # 8. each agent takes observation and applies decision rule to observation to create CAOS
            # action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
            # to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
            # code should live inside of the GATE agent subclass)
            # gets action in CAOS format
            _LOGGER.debug("Getting agent action")
            agent_action, action_options = agent.get_action(agent_obs, agent_reward)
            # 9. CAOS action is converted into request (extra information might be needed to enrich
            # the request, this is what the execution definition is there for)
            _LOGGER.debug(f"Formatting agent action {agent_action}")  # maybe too many debug log statements
            agent_request = agent.format_request(agent_action, action_options)

            # 10. primaite session receives the action from the agents and asks the simulation to apply each
            _LOGGER.debug(f"Sending request to simulation: {agent_request}")
            self.simulation.apply_request(agent_request)

        _LOGGER.debug(f"Initiating simulation step {self.step_counter}")
        self.simulation.apply_timestep(self.step_counter)
        self.step_counter += 1

    def reset(self) -> None:
        """Reset the session, this will reset the simulation."""
        # NOTE(review): returns the NotImplemented constant instead of raising
        # NotImplementedError — callers (e.g. PrimaiteGATEClient.reset) proceed silently.
        return NotImplemented

    def close(self) -> None:
        """Close the session, this will stop the gate client and close the simulation."""
        # NOTE(review): same as reset() — NotImplemented constant rather than an exception.
        return NotImplemented

    @classmethod
    def from_config(cls, cfg: dict) -> "PrimaiteSession":
        """Create a PrimaiteSession object from a config dictionary.

        The config dictionary should have the following top-level keys:
        1. training_config: options for training the RL agent. Used by GATE.
        2. game_config: options for the game itself. Used by PrimaiteSession.
        3. simulation: defines the network topology and the initial state of the simulation.

        The specification for each of the three major areas is described in a separate documentation page.
        # TODO: create documentation page and add links to it here.

        :param cfg: The config dictionary.
        :type cfg: dict
        :return: A PrimaiteSession object.
        :rtype: PrimaiteSession
        """
        sess = cls()
        sess.options = PrimaiteSessionOptions(
            ports=cfg["game_config"]["ports"],
            protocols=cfg["game_config"]["protocols"],
        )
        sess.training_options = TrainingOptions(**cfg["training_config"])
        sim = sess.simulation
        net = sim.network

        # Re-initialise the ref maps (already created empty in __init__).
        sess.ref_map_nodes: Dict[str, Node] = {}
        sess.ref_map_services: Dict[str, Service] = {}
        sess.ref_map_links: Dict[str, Link] = {}

        # 1. create nodes
        nodes_cfg = cfg["simulation"]["network"]["nodes"]
        links_cfg = cfg["simulation"]["network"]["links"]
        for node_cfg in nodes_cfg:
            node_ref = node_cfg["ref"]
            n_type = node_cfg["type"]
            if n_type == "computer":
                new_node = Computer(
                    hostname=node_cfg["hostname"],
                    ip_address=node_cfg["ip_address"],
                    subnet_mask=node_cfg["subnet_mask"],
                    default_gateway=node_cfg["default_gateway"],
                    # NOTE(review): computers require 'dns_server' (KeyError if absent) while
                    # servers use .get() below — confirm the asymmetry is intended.
                    dns_server=node_cfg["dns_server"],
                )
            elif n_type == "server":
                new_node = Server(
                    hostname=node_cfg["hostname"],
                    ip_address=node_cfg["ip_address"],
                    subnet_mask=node_cfg["subnet_mask"],
                    default_gateway=node_cfg["default_gateway"],
                    dns_server=node_cfg.get("dns_server"),
                )
            elif n_type == "switch":
                new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
            elif n_type == "router":
                new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
                if "ports" in node_cfg:
                    for port_num, port_cfg in node_cfg["ports"].items():
                        new_node.configure_port(
                            port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
                        )
                if "acl" in node_cfg:
                    for r_num, r_cfg in node_cfg["acl"].items():
                        # excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
                        # this: 'r_cfg.get('src_port')'
                        # Port/IPProtocol. TODO Refactor
                        new_node.acl.add_rule(
                            action=ACLAction[r_cfg["action"]],
                            src_port=None if not (p := r_cfg.get("src_port")) else Port[p],
                            dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p],
                            protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p],
                            # NOTE(review): src and dst both read the same 'ip_address' key —
                            # these likely should be distinct config keys; confirm.
                            src_ip_address=r_cfg.get("ip_address"),
                            dst_ip_address=r_cfg.get("ip_address"),
                            position=r_num,
                        )
            else:
                # NOTE(review): an unknown type only prints; `new_node` is then unbound (or
                # stale from the previous iteration) and the code below will misbehave.
                print("invalid node type")
            if "services" in node_cfg:
                for service_cfg in node_cfg["services"]:
                    service_ref = service_cfg["ref"]
                    service_type = service_cfg["type"]
                    service_types_mapping = {
                        "DNSClient": DNSClient,  # key is equal to the 'name' attr of the service class itself.
                        "DNSServer": DNSServer,
                        "DatabaseClient": DatabaseClient,
                        "DatabaseService": DatabaseService,
                        "WebServer": WebServer,
                        "DataManipulationBot": DataManipulationBot,
                    }
                    if service_type in service_types_mapping:
                        print(f"installing {service_type} on node {new_node.hostname}")
                        new_node.software_manager.install(service_types_mapping[service_type])
                        new_service = new_node.software_manager.software[service_type]
                        sess.ref_map_services[service_ref] = new_service
                    else:
                        print(f"service type not found {service_type}")
                    # service-dependent options
                    # NOTE(review): if the type lookup above failed, `new_service` below refers
                    # to the service installed in a previous iteration (or is unbound).
                    if service_type == "DatabaseClient":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            if "db_server_ip" in opt:
                                new_service.configure(server_ip_address=IPv4Address(opt["db_server_ip"]))
                    if service_type == "DNSServer":
                        if "options" in service_cfg:
                            opt = service_cfg["options"]
                            if "domain_mapping" in opt:
                                for domain, ip in opt["domain_mapping"].items():
                                    new_service.dns_register(domain, ip)
            if "applications" in node_cfg:
                for application_cfg in node_cfg["applications"]:
                    application_ref = application_cfg["ref"]
                    application_type = application_cfg["type"]
                    application_types_mapping = {
                        "WebBrowser": WebBrowser,
                    }
                    if application_type in application_types_mapping:
                        new_node.software_manager.install(application_types_mapping[application_type])
                        new_application = new_node.software_manager.software[application_type]
                        sess.ref_map_applications[application_ref] = new_application
                    else:
                        print(f"application type not found {application_type}")
            if "nics" in node_cfg:
                for nic_num, nic_cfg in node_cfg["nics"].items():
                    new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))

            net.add_node(new_node)
            new_node.power_on()
            # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
            sess.ref_map_nodes[node_ref] = new_node.uuid

        # 2. create links between nodes
        for link_cfg in links_cfg:
            node_a = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_a_ref"]]]
            node_b = net.nodes[sess.ref_map_nodes[link_cfg["endpoint_b_ref"]]]
            # Switches expose ports via switch_ports; other nodes via ethernet_port.
            if isinstance(node_a, Switch):
                endpoint_a = node_a.switch_ports[link_cfg["endpoint_a_port"]]
            else:
                endpoint_a = node_a.ethernet_port[link_cfg["endpoint_a_port"]]
            if isinstance(node_b, Switch):
                endpoint_b = node_b.switch_ports[link_cfg["endpoint_b_port"]]
            else:
                endpoint_b = node_b.ethernet_port[link_cfg["endpoint_b_port"]]
            new_link = net.connect(endpoint_a=endpoint_a, endpoint_b=endpoint_b)
            sess.ref_map_links[link_cfg["ref"]] = new_link.uuid

        # 3. create agents
        game_cfg = cfg["game_config"]
        agents_cfg = game_cfg["agents"]

        for agent_cfg in agents_cfg:
            agent_ref = agent_cfg["ref"]  # noqa: F841
            agent_type = agent_cfg["type"]
            action_space_cfg = agent_cfg["action_space"]
            observation_space_cfg = agent_cfg["observation_space"]
            reward_function_cfg = agent_cfg["reward_function"]

            # CREATE OBSERVATION SPACE
            obs_space = ObservationSpace.from_config(observation_space_cfg, sess)

            # CREATE ACTION SPACE
            action_space_cfg["options"]["node_uuids"] = []
            # if a list of nodes is defined, convert them from node references to node UUIDs
            for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
                if "node_ref" in action_node_option:
                    node_uuid = sess.ref_map_nodes[action_node_option["node_ref"]]
                    action_space_cfg["options"]["node_uuids"].append(node_uuid)
            # Each action space can potentially have a different list of nodes that it can apply to. Therefore,
            # we will pass node_uuids as a part of the action space config.
            # However, it's not possible to specify the node uuids directly in the config, as they are generated
            # dynamically, so we have to translate node references to uuids before passing this config on.

            if "action_list" in action_space_cfg:
                for action_config in action_space_cfg["action_list"]:
                    if "options" in action_config:
                        if "target_router_ref" in action_config["options"]:
                            # Translate router reference to UUID, same as node refs above.
                            _target = action_config["options"]["target_router_ref"]
                            action_config["options"]["target_router_uuid"] = sess.ref_map_nodes[_target]

            action_space = ActionManager.from_config(sess, action_space_cfg)

            # CREATE REWARD FUNCTION
            rew_function = RewardFunction.from_config(reward_function_cfg, session=sess)

            # CREATE AGENT
            # NOTE(review): all three agent types currently instantiate RandomAgent —
            # scripted/GATE behaviours are pending (see TODO below).
            if agent_type == "GreenWebBrowsingAgent":
                # TODO: implement non-random agents and fix this parsing
                new_agent = RandomAgent(
                    agent_name=agent_cfg["ref"],
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=rew_function,
                )
                sess.agents.append(new_agent)
            elif agent_type == "GATERLAgent":
                new_agent = RandomAgent(
                    agent_name=agent_cfg["ref"],
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=rew_function,
                )
                sess.agents.append(new_agent)
                # The GATE RL agent is also kept as the session's designated learner.
                sess.rl_agent = new_agent
            elif agent_type == "RedDatabaseCorruptingAgent":
                new_agent = RandomAgent(
                    agent_name=agent_cfg["ref"],
                    action_space=action_space,
                    observation_space=obs_space,
                    reward_function=rew_function,
                )
                sess.agents.append(new_agent)
            else:
                print("agent type not found")

        return sess
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Network connections between nodes in the simulation."""
|
||||
@@ -1,114 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The link class."""
|
||||
from typing import List
|
||||
|
||||
from primaite.common.protocol import Protocol
|
||||
|
||||
|
||||
class Link(object):
    """A network link between two nodes, carrying traffic for a set of protocols."""

    def __init__(
        self,
        _id: str,
        _bandwidth: int,
        _source_node_name: str,
        _dest_node_name: str,
        _services: List[str],
    ) -> None:
        """
        Initialise a Link within the simulated network.

        :param _id: The IER id
        :param _bandwidth: The bandwidth of the link (bps)
        :param _source_node_name: The name of the source node
        :param _dest_node_name: The name of the destination node
        :param _services: The names of the protocols to add to the link
        """
        self.id: str = _id
        self.bandwidth: int = _bandwidth
        self.source_node_name: str = _source_node_name
        self.dest_node_name: str = _dest_node_name
        self.protocol_list: List["Protocol"] = []

        # Add the default protocols
        for protocol_name in _services:
            self.add_protocol(protocol_name)

    def add_protocol(self, _protocol: str) -> None:
        """
        Adds a new protocol to the list of protocols on this link.

        :param _protocol: The name of the protocol to be added.
        """
        self.protocol_list.append(Protocol(_protocol))

    def get_id(self) -> str:
        """
        Gets link ID.

        :return: Link ID
        """
        return self.id

    def get_source_node_name(self) -> str:
        """
        Gets source node name.

        :return: Source node name
        """
        return self.source_node_name

    def get_dest_node_name(self) -> str:
        """
        Gets destination node name.

        :return: Destination node name
        """
        return self.dest_node_name

    def get_bandwidth(self) -> int:
        """
        Gets bandwidth of link.

        :return: Link bandwidth (bps)
        """
        return self.bandwidth

    def get_protocol_list(self) -> List["Protocol"]:
        """
        Gets list of protocols on this link.

        :return: List of protocols on this link
        """
        return self.protocol_list

    def get_current_load(self) -> int:
        """
        Gets current total load on this link.

        :return: Total load on this link (bps)
        """
        # Sum the load reported by every protocol carried on the link.
        return sum(protocol.get_load() for protocol in self.protocol_list)

    def add_protocol_load(self, _protocol: str, _load: int) -> None:
        """
        Adds a loading to a protocol on this link.

        :param _protocol: The name of the protocol to load
        :param _load: The amount to load (bps)
        """
        # Apply the load to every protocol whose name matches.
        for protocol in self.protocol_list:
            if protocol.get_name() == _protocol:
                protocol.add_load(_load)

    def clear_traffic(self) -> None:
        """Clears all traffic on this link."""
        for protocol in self.protocol_list:
            protocol.clear_load()
|
||||
@@ -5,17 +5,16 @@ from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
from src.primaite.config.load import load
|
||||
from src.primaite.game.session import PrimaiteSession
|
||||
|
||||
# from src.primaite.primaite_session import PrimaiteSession
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
    training_config_path: Optional[Union[str, Path]] = "",
    lay_down_config_path: Optional[Union[str, Path]] = "",
    session_path: Optional[Union[str, Path]] = None,
    legacy_training_config: bool = False,
    legacy_lay_down_config: bool = False,
    config_path: Optional[Union[str, Path]] = "",
) -> None:
    """
    Run the PrimAITE Session.

    :param training_config_path: Unused; retained for backwards compatibility with
        the PrimAITE < 3.0 entry point.
    :param lay_down_config_path: Unused; retained for backwards compatibility.
    :param session_path: Unused; retained for backwards compatibility.
    :param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
        otherwise False. Unused; retained for backwards compatibility.
    :param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
        otherwise False. Unused; retained for backwards compatibility.
    :param config_path: Path to the single YAML config file that describes the session.
    """
    # NOTE(review): the pre-merge flow (PrimaiteSession(training_config_path, ...) followed by
    # setup/learn/evaluate) was replaced by the config-driven session; only config_path is used now.
    cfg = load(config_path)
    sess = PrimaiteSession.from_config(cfg=cfg)
    sess.start_session()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: a single --config argument selects the session config.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config")

    args = parser.parse_args()
    if not args.config:
        _LOGGER.error("Please provide a config file using the --config " "argument")

    # Pass the path as config_path: the config-driven run() reads config_path, not
    # session_path, so passing it as session_path would silently be ignored.
    run(config_path=args.config)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Nodes represent network hosts in the simulation."""
|
||||
@@ -1,208 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActiveNode(Node):
    """An active node (i.e. not an actuator) with software and file-system state."""

    def __init__(
        self,
        node_id: str,
        name: str,
        node_type: NodeType,
        priority: Priority,
        hardware_state: HardwareState,
        ip_address: str,
        software_state: SoftwareState,
        file_system_state: FileSystemState,
        config_values: TrainingConfig,
    ) -> None:
        """
        Initialise an active node.

        :param node_id: The node ID
        :param name: The node name
        :param node_type: The node type (enum)
        :param priority: The node priority (enum)
        :param hardware_state: The node Hardware State
        :param ip_address: The node IP address
        :param software_state: The node Software State
        :param file_system_state: The node file system state
        :param config_values: The config values
        """
        super().__init__(node_id, name, node_type, priority, hardware_state, config_values)
        self.ip_address: str = ip_address
        # Related to Software
        self._software_state: SoftwareState = software_state
        self.patching_count: int = 0  # steps remaining in the current OS patching cycle
        # Related to File System
        self.file_system_state_actual: FileSystemState = file_system_state
        self.file_system_state_observed: FileSystemState = file_system_state
        self.file_system_scanning: bool = False  # True while a file system scan is running
        self.file_system_scanning_count: int = 0  # steps remaining in the current scan
        self.file_system_action_count: int = 0  # steps remaining in a repair/restore

    @property
    def software_state(self) -> SoftwareState:
        """
        Get the software_state.

        :return: The software_state.
        """
        return self._software_state

    @software_state.setter
    def software_state(self, software_state: SoftwareState) -> None:
        """
        Set the software_state.

        The state is only changed while the node hardware is not OFF; otherwise
        the attempt is logged and ignored.

        :param software_state: Software State.
        """
        if self.hardware_state != HardwareState.OFF:
            self._software_state = software_state
            # Entering PATCHING starts the patching countdown.
            if software_state == SoftwareState.PATCHING:
                self.patching_count = self.config_values.os_patching_duration
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so OS State cannot be "
                f"changed. "
                f"Node.node_id:{self.node_id}, "
                f"Node.hardware_state:{self.hardware_state}, "
                f"Node.software_state:{self._software_state}"
            )

    def set_software_state_if_not_compromised(self, software_state: SoftwareState) -> None:
        """
        Sets Software State if the node is not compromised.

        Used by green PoL so that it cannot overturn a COMPROMISED state.

        :param software_state: Software State
        """
        if self.hardware_state != HardwareState.OFF:
            if self._software_state != SoftwareState.COMPROMISED:
                self._software_state = software_state
                if software_state == SoftwareState.PATCHING:
                    self.patching_count = self.config_values.os_patching_duration
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so OS State cannot be changed."
                f"Node.node_id:{self.node_id}, "
                f"Node.hardware_state:{self.hardware_state}, "
                f"Node.software_state:{self._software_state}"
            )

    def update_os_patching_status(self) -> None:
        """Updates operating system status based on patching cycle."""
        self.patching_count -= 1
        # When the countdown expires the OS returns to GOOD.
        if self.patching_count <= 0:
            self.patching_count = 0
            self._software_state = SoftwareState.GOOD

    def set_file_system_state(self, file_system_state: FileSystemState) -> None:
        """
        Sets the file system state (actual and observed).

        Has no effect (other than logging) while the node hardware is OFF.

        :param file_system_state: File system state
        """
        if self.hardware_state != HardwareState.OFF:
            self.file_system_state_actual = file_system_state

            # REPAIRING / RESTORING start their respective countdowns and are
            # immediately visible in the observed state.
            if file_system_state == FileSystemState.REPAIRING:
                self.file_system_action_count = self.config_values.file_system_repairing_limit
                self.file_system_state_observed = FileSystemState.REPAIRING
            elif file_system_state == FileSystemState.RESTORING:
                self.file_system_action_count = self.config_values.file_system_restoring_limit
                self.file_system_state_observed = FileSystemState.RESTORING
            elif file_system_state == FileSystemState.GOOD:
                self.file_system_state_observed = FileSystemState.GOOD
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so File System State "
                f"cannot be changed. "
                f"Node.node_id:{self.node_id}, "
                f"Node.hardware_state:{self.hardware_state}, "
                f"Node.file_system_state.actual:{self.file_system_state_actual}"
            )

    def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState) -> None:
        """
        Sets the file system state (actual and observed) if not in a compromised state.

        Use for green PoL to prevent it overturning a compromised state

        :param file_system_state: File system state
        """
        if self.hardware_state != HardwareState.OFF:
            if (
                self.file_system_state_actual != FileSystemState.CORRUPT
                and self.file_system_state_actual != FileSystemState.DESTROYED
            ):
                self.file_system_state_actual = file_system_state

                if file_system_state == FileSystemState.REPAIRING:
                    self.file_system_action_count = self.config_values.file_system_repairing_limit
                    self.file_system_state_observed = FileSystemState.REPAIRING
                elif file_system_state == FileSystemState.RESTORING:
                    self.file_system_action_count = self.config_values.file_system_restoring_limit
                    self.file_system_state_observed = FileSystemState.RESTORING
                elif file_system_state == FileSystemState.GOOD:
                    self.file_system_state_observed = FileSystemState.GOOD
        else:
            _LOGGER.info(
                f"The Nodes hardware state is OFF so File System State (if not "
                f"compromised) cannot be changed. "
                f"Node.node_id:{self.node_id}, "
                f"Node.hardware_state:{self.hardware_state}, "
                f"Node.file_system_state.actual:{self.file_system_state_actual}"
            )

    def start_file_system_scan(self) -> None:
        """Starts a file system scan."""
        self.file_system_scanning = True
        self.file_system_scanning_count = self.config_values.file_system_scanning_limit

    def update_file_system_state(self) -> None:
        """Updates file system status based on scanning/restore/repair cycle."""
        # Decrement both the action count (for restoring or repairing) and the scanning count
        self.file_system_action_count -= 1
        self.file_system_scanning_count -= 1

        # Repairing / Restoring updates
        if self.file_system_action_count <= 0:
            self.file_system_action_count = 0
            if (
                self.file_system_state_actual == FileSystemState.REPAIRING
                or self.file_system_state_actual == FileSystemState.RESTORING
            ):
                self.file_system_state_actual = FileSystemState.GOOD
                self.file_system_state_observed = FileSystemState.GOOD

        # Scanning updates: when a scan finishes, the observed state is
        # refreshed to match the actual state.
        if self.file_system_scanning and self.file_system_scanning_count < 0:
            self.file_system_state_observed = self.file_system_state_actual
            self.file_system_scanning = False
            self.file_system_scanning_count = 0

    def update_resetting_status(self) -> None:
        """Updates the reset count & makes software and file state to GOOD."""
        super().update_resetting_status()
        if self.resetting_count <= 0:
            self.file_system_state_actual = FileSystemState.GOOD
            self.software_state = SoftwareState.GOOD

    def update_booting_status(self) -> None:
        """Updates the booting software and file state to GOOD."""
        super().update_booting_status()
        if self.booting_count <= 0:
            self.file_system_state_actual = FileSystemState.GOOD
            self.software_state = SoftwareState.GOOD
|
||||
@@ -1,79 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
|
||||
class Node:
    """Base class for all simulated network nodes."""

    def __init__(
        self,
        node_id: str,
        name: str,
        node_type: NodeType,
        priority: Priority,
        hardware_state: HardwareState,
        config_values: TrainingConfig,
    ) -> None:
        """
        Initialise a node.

        :param node_id: The node id.
        :param name: The name of the node.
        :param node_type: The type of the node.
        :param priority: The priority of the node.
        :param hardware_state: The state of the node.
        :param config_values: Config values.
        """
        # Immutable identity of the node.
        self.node_id: Final[str] = node_id
        self.name: Final[str] = name
        self.node_type: Final[NodeType] = node_type
        self.priority = priority
        # Mutable hardware lifecycle state.
        self.hardware_state: HardwareState = hardware_state
        self.config_values: TrainingConfig = config_values
        # Countdown timers driving the reset / boot / shutdown transitions.
        self.resetting_count: int = 0
        self.booting_count: int = 0
        self.shutting_down_count: int = 0

    def __repr__(self) -> str:
        """Returns the name of the node."""
        return self.name

    def turn_on(self) -> None:
        """Sets the node state to BOOTING and starts the boot countdown."""
        self.hardware_state = HardwareState.BOOTING
        self.booting_count = self.config_values.node_booting_duration

    def turn_off(self) -> None:
        """Sets the node state to OFF and starts the shutdown countdown."""
        self.hardware_state = HardwareState.OFF
        self.shutting_down_count = self.config_values.node_shutdown_duration

    def reset(self) -> None:
        """Sets the node state to RESETTING and starts the reset countdown."""
        self.hardware_state = HardwareState.RESETTING
        self.resetting_count = self.config_values.node_reset_duration

    def update_resetting_status(self) -> None:
        """Ticks down the reset counter; the node comes back ON once it expires."""
        self.resetting_count -= 1
        if self.resetting_count > 0:
            return
        self.resetting_count = 0
        self.hardware_state = HardwareState.ON

    def update_booting_status(self) -> None:
        """Ticks down the boot counter; the node comes ON once it expires."""
        self.booting_count -= 1
        if self.booting_count > 0:
            return
        self.booting_count = 0
        self.hardware_state = HardwareState.ON

    def update_shutdown_status(self) -> None:
        """Ticks down the shutdown counter; the node settles OFF once it expires."""
        self.shutting_down_count -= 1
        if self.shutting_down_count > 0:
            return
        self.shutting_down_count = 0
        self.hardware_state = HardwareState.OFF
|
||||
@@ -1,94 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLType, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionGreen(object):
    """A green-agent pattern-of-life instruction applied to a node over a step window."""

    def __init__(
        self,
        _id: str,
        _start_step: int,
        _end_step: int,
        _node_id: str,
        _node_pol_type: "NodePOLType",
        _service_name: str,
        _state: Union["HardwareState", "SoftwareState", "FileSystemState"],
    ) -> None:
        """
        Initialise the Node State Instruction.

        :param _id: The node state instruction id
        :param _start_step: The start step of the instruction
        :param _end_step: The end step of the instruction
        :param _node_id: The id of the associated node
        :param _node_pol_type: The pattern of life type
        :param _service_name: The service name
        :param _state: The state (node or service)
        """
        self.id = _id
        # Inclusive step window during which this instruction is active.
        self.start_step = _start_step
        self.end_step = _end_step
        self.node_id = _node_id
        self.node_pol_type: "NodePOLType" = _node_pol_type
        # Only meaningful when this is a service instruction.
        self.service_name: str = _service_name
        # TODO: confirm type of state
        self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _state

    def get_start_step(self) -> int:
        """
        Gets the start step.

        :return: The start step
        """
        return self.start_step

    def get_end_step(self) -> int:
        """
        Gets the end step.

        :return: The end step
        """
        return self.end_step

    def get_node_id(self) -> str:
        """
        Gets the node ID.

        :return: The node ID
        """
        return self.node_id

    def get_node_pol_type(self) -> "NodePOLType":
        """
        Gets the node pattern of life type (enum).

        :return: The node pattern of life type (enum)
        """
        return self.node_pol_type

    def get_service_name(self) -> str:
        """
        Gets the service name.

        :return: The service name
        """
        return self.service_name

    def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
        """
        Gets the state (node or service).

        :return: The state (node or service)
        """
        return self.state
|
||||
@@ -1,143 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from primaite.common.enums import NodePOLType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodePOLInitiator, SoftwareState
|
||||
|
||||
|
||||
class NodeStateInstructionRed:
    """A red-agent pattern-of-life instruction targeting a node over a step window."""

    def __init__(
        self,
        _id: str,
        _start_step: int,
        _end_step: int,
        _target_node_id: str,
        _pol_initiator: "NodePOLInitiator",
        _pol_type: NodePOLType,
        pol_protocol: str,
        _pol_state: Union["HardwareState", "SoftwareState", "FileSystemState"],
        _pol_source_node_id: str,
        _pol_source_node_service: str,
        _pol_source_node_service_state: str,
    ) -> None:
        """
        Initialise the Node State Instruction for the red agent.

        :param _id: The node state instruction id
        :param _start_step: The start step of the instruction
        :param _end_step: The end step of the instruction
        :param _target_node_id: The id of the associated node
        :param _pol_initiator: The way the PoL is applied (DIRECT, IER or SERVICE)
        :param _pol_type: The pattern of life type
        :param pol_protocol: The pattern of life protocol/service affected
        :param _pol_state: The state (node or service)
        :param _pol_source_node_id: The source node Id (used for initiator type SERVICE)
        :param _pol_source_node_service: The source node service (used for initiator type SERVICE)
        :param _pol_source_node_service_state: The source node service state (used for initiator type SERVICE)
        """
        self.id: str = _id
        # Inclusive step window during which this instruction is active.
        self.start_step: int = _start_step
        self.end_step: int = _end_step
        self.target_node_id: str = _target_node_id
        self.initiator: "NodePOLInitiator" = _pol_initiator
        self.pol_type: NodePOLType = _pol_type
        # Only meaningful when this is a service instruction.
        self.service_name: str = pol_protocol
        self.state: Union["HardwareState", "SoftwareState", "FileSystemState"] = _pol_state
        # Source-side fields, used only for initiator type SERVICE.
        self.source_node_id: str = _pol_source_node_id
        self.source_node_service: str = _pol_source_node_service
        self.source_node_service_state = _pol_source_node_service_state

    def get_start_step(self) -> int:
        """
        Gets the start step.

        :return: The start step
        """
        return self.start_step

    def get_end_step(self) -> int:
        """
        Gets the end step.

        :return: The end step
        """
        return self.end_step

    def get_target_node_id(self) -> str:
        """
        Gets the node ID.

        :return: The node ID
        """
        return self.target_node_id

    def get_initiator(self) -> "NodePOLInitiator":
        """
        Gets the initiator.

        :return: The initiator
        """
        return self.initiator

    def get_pol_type(self) -> NodePOLType:
        """
        Gets the node pattern of life type (enum).

        :return: The node pattern of life type (enum)
        """
        return self.pol_type

    def get_service_name(self) -> str:
        """
        Gets the service name.

        :return: The service name
        """
        return self.service_name

    def get_state(self) -> Union["HardwareState", "SoftwareState", "FileSystemState"]:
        """
        Gets the state (node or service).

        :return: The state (node or service)
        """
        return self.state

    def get_source_node_id(self) -> str:
        """
        Gets the source node id (used for initiator type SERVICE).

        :return: The source node id
        """
        return self.source_node_id

    def get_source_node_service(self) -> str:
        """
        Gets the source node service (used for initiator type SERVICE).

        :return: The source node service
        """
        return self.source_node_service

    def get_source_node_service_state(self) -> str:
        """
        Gets the source node service state (used for initiator type SERVICE).

        :return: The source node service state
        """
        return self.source_node_service_state
|
||||
@@ -1,42 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.node import Node
|
||||
|
||||
|
||||
class PassiveNode(Node):
    """A passive node (an actuator); currently just the base Node behaviour."""

    def __init__(
        self,
        node_id: str,
        name: str,
        node_type: NodeType,
        priority: Priority,
        hardware_state: HardwareState,
        config_values: TrainingConfig,
    ) -> None:
        """
        Initialise a passive node.

        :param node_id: The node id.
        :param name: The name of the node.
        :param node_type: The type of the node.
        :param priority: The priority of the node.
        :param hardware_state: The state of the node.
        :param config_values: Config values.
        """
        # No passive-node-specific state yet; defer everything to Node.
        super().__init__(node_id, name, node_type, priority, hardware_state, config_values)

    @property
    def ip_address(self) -> str:
        """
        Gets the node IP address as an empty string.

        There is no concept of an IP address for passive nodes for now.

        :return: The node IP address.
        """
        return ""
|
||||
@@ -1,190 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
|
||||
_LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceNode(ActiveNode):
|
||||
"""ServiceNode class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
name: str,
|
||||
node_type: NodeType,
|
||||
priority: Priority,
|
||||
hardware_state: HardwareState,
|
||||
ip_address: str,
|
||||
software_state: SoftwareState,
|
||||
file_system_state: FileSystemState,
|
||||
config_values: TrainingConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a Service Node.
|
||||
|
||||
:param node_id: The node ID
|
||||
:param name: The node name
|
||||
:param node_type: The node type (enum)
|
||||
:param priority: The node priority (enum)
|
||||
:param hardware_state: The node Hardware State
|
||||
:param ip_address: The node IP address
|
||||
:param software_state: The node Software State
|
||||
:param file_system_state: The node file system state
|
||||
:param config_values: The config values
|
||||
"""
|
||||
super().__init__(
|
||||
node_id,
|
||||
name,
|
||||
node_type,
|
||||
priority,
|
||||
hardware_state,
|
||||
ip_address,
|
||||
software_state,
|
||||
file_system_state,
|
||||
config_values,
|
||||
)
|
||||
self.services: Dict[str, Service] = {}
|
||||
|
||||
def add_service(self, service: Service) -> None:
|
||||
"""
|
||||
Adds a service to the node.
|
||||
|
||||
:param service: The service to add
|
||||
"""
|
||||
self.services[service.name] = service
|
||||
|
||||
def has_service(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is on a node.
|
||||
|
||||
:param protocol_name: The service (protocol)e.
|
||||
:return: True if service (protocol) is on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def service_running(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in a running state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in a running state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state != SoftwareState.PATCHING:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def service_is_overwhelmed(self, protocol_name: str) -> bool:
|
||||
"""
|
||||
Indicates whether a service is in an overwhelmed state on the node.
|
||||
|
||||
:param protocol_name: The service (protocol)
|
||||
:return: True if service (protocol) is in an overwhelmed state on the node, otherwise False.
|
||||
"""
|
||||
for service_key, service_value in self.services.items():
|
||||
if service_key == protocol_name:
|
||||
if service_value.software_state == SoftwareState.OVERWHELMED:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
def set_service_state(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
# Can't set to compromised if you're in a patching state
|
||||
if (
|
||||
software_state == SoftwareState.COMPROMISED
|
||||
and service_value.software_state != SoftwareState.PATCHING
|
||||
) or software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState) -> None:
|
||||
"""
|
||||
Sets the software_state of a service (protocol) on the node.
|
||||
|
||||
Done if the software_state is not "compromised".
|
||||
|
||||
:param protocol_name: The service (protocol).
|
||||
:param software_state: The software_state.
|
||||
"""
|
||||
if self.hardware_state != HardwareState.OFF:
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
if service_value.software_state != SoftwareState.COMPROMISED:
|
||||
service_value.software_state = software_state
|
||||
if software_state == SoftwareState.PATCHING:
|
||||
service_value.patching_count = self.config_values.service_patching_duration
|
||||
else:
|
||||
_LOGGER.info(
|
||||
f"The Nodes hardware state is OFF so the state of a service "
|
||||
f"cannot be changed. "
|
||||
f"Node.node_id:{self.node_id}, "
|
||||
f"Node.hardware_state:{self.hardware_state}, "
|
||||
f"Node.services[<key>]:{protocol_name}, "
|
||||
f"Node.services[<key>].software_state:{software_state}"
|
||||
)
|
||||
|
||||
def get_service_state(self, protocol_name: str) -> SoftwareState:
|
||||
"""
|
||||
Gets the state of a service.
|
||||
|
||||
:return: The software_state of the service.
|
||||
"""
|
||||
service_key = protocol_name
|
||||
service_value = self.services.get(service_key)
|
||||
if service_value:
|
||||
return service_value.software_state
|
||||
|
||||
def update_services_patching_status(self) -> None:
    """Advance the patching countdown on every service that is currently patching."""
    # Keys are not needed here, so iterate the values view directly.
    for service in self.services.values():
        if service.software_state == SoftwareState.PATCHING:
            service.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self) -> None:
    """Update the resetting counter and restore services once it reaches 0."""
    super().update_resetting_status()
    if self.resetting_count > 0:
        # Reset still in progress.
        return
    # Reset complete: every service comes back in a good state.
    for service in self.services.values():
        service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self) -> None:
    """Update the booting counter and restore services once it reaches 0."""
    super().update_booting_status()
    if self.booting_count > 0:
        # Boot still in progress.
        return
    # Boot complete: every service starts in a good state.
    for service in self.services.values():
        service.software_state = SoftwareState.GOOD
|
||||
0
src/primaite/notebooks/.gitkeep
Normal file
0
src/primaite/notebooks/.gitkeep
Normal file
@@ -1,107 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.network.networks import arcd_uc2_network\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = arcd_uc2_network()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### set up some services to test if actions are working"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv = net.get_node_by_hostname('database_server')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.simulator.system.services.database_service import DatabaseService"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_svc = DatabaseService(file_system=db_serv.file_system)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv.install_service(db_svc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_serv.describe_state()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Pattern of Life- Represents the actions of users on the network."""
|
||||
@@ -1,264 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def _green_ier_endpoint_valid(node: NodeUnion, protocol: str) -> bool:
    """Return True when *node* can act as an endpoint of a green IER for *protocol*.

    Switches only need to be powered on and not patching. Actuators are not
    yet modelled and are treated as valid (TO DO). Any other node must also be
    running the matching service, which must not be overwhelmed.

    The source and destination checks in ``apply_iers`` were previously
    duplicated verbatim; this helper factors them out.
    """
    # TODO: should be using isinstance rather than checking node type attribute. IE. just because it's
    # a switch doesn't mean it has a software state? It could be a PassiveNode or ActiveNode.
    if node.node_type == NodeType.SWITCH:
        return node.hardware_state == HardwareState.ON and node.software_state != SoftwareState.PATCHING
    if node.node_type == NodeType.ACTUATOR:
        # Actuator behaviour is not implemented yet (TO DO); treat as valid.
        return True
    # It's not a switch or an actuator (so active node).
    if node.hardware_state == HardwareState.ON and node.software_state != SoftwareState.PATCHING:
        if node.has_service(protocol):
            return node.service_running(protocol) and not node.service_is_overwhelmed(protocol)
        # IER is not valid on this node (shouldn't happen if the IER has been written correctly).
        return False
    # Node is off or patching - IER no longer valid.
    return False


def apply_iers(
    network: MultiGraph,
    nodes: Dict[str, NodeUnion],
    links: Dict[str, Link],
    iers: Dict[str, IER],
    acl: AccessControlList,
    step: int,
) -> None:
    """
    Applies IERs to the links (link pattern of life).

    Args:
        network: The network modelled in the environment
        nodes: The nodes within the environment
        links: The links within the environment
        iers: The IERs to apply to the links
        acl: The Access Control List
        step: The step number.
    """
    if _VERBOSE:
        print("Applying IERs")

    # Go through each IER and check the conditions for it being applied.
    # If everything is in place, apply the IER protocol load to the relevant links.
    for ier_value in iers.values():
        start_step = ier_value.get_start_step()
        stop_step = ier_value.get_end_step()
        protocol = ier_value.get_protocol()
        port = ier_value.get_port()
        load = ier_value.get_load()

        # Every IER starts each step as not running; it is only switched back
        # on if it passes all of the checks below.
        ier_value.set_is_running(False)

        if step < start_step or step > stop_step:
            # IER is not scheduled for this time step.
            continue

        # Get the source and destination node for this IER.
        source_node = nodes[ier_value.get_source_node_id()]
        dest_node = nodes[ier_value.get_dest_node_id()]

        # 1 & 2. Both endpoints must be operational (and, where applicable,
        # running the matching, non-overwhelmed service).
        source_valid = _green_ier_endpoint_valid(source_node, protocol)
        dest_valid = _green_ier_endpoint_valid(dest_node, protocol)

        # 3. Check that the ACL doesn't block it.
        acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
        if _VERBOSE:
            if acl_block:
                print(
                    "ACL block on source: "
                    + source_node.ip_address
                    + ", dest: "
                    + dest_node.ip_address
                    + ", protocol: "
                    + protocol
                    + ", port: "
                    + port
                )
            else:
                print("No ACL block")

        # Both endpoints must be valid and there must be no ACL block.
        if not (source_valid and dest_valid and not acl_block):
            if _VERBOSE:
                print("Source, Dest or ACL were not valid")
            continue

        if _VERBOSE:
            print("Source, Dest and ACL valid")

        # Get the shortest path (i.e. nodes) between source and destination.
        # We might have a switch in the path, so check ALL nodes on it are operational.
        path_node_list = shortest_path(network, source_node, dest_node)
        path_valid = all(
            node.hardware_state == HardwareState.ON and node.software_state != SoftwareState.PATCHING
            for node in path_node_list
        )
        if not path_valid:
            # One of the nodes on the path is not operational.
            if _VERBOSE:
                print("Path not valid - one or more nodes not operational")
            continue

        if _VERBOSE:
            print("Applying IER to link(s)")

        # Resolve the links along the path once; the list is reused for both
        # the capacity check and applying the load.
        path_links = []
        for node_a, node_b in zip(path_node_list, path_node_list[1:]):
            edge_dict = network.get_edge_data(node_a, node_b)
            path_links.append(links[edge_dict[0].get("id")])

        # The IER only runs if no link on the path would exceed its bandwidth.
        if any((link.get_current_load() + load) > link.get_bandwidth() for link in path_links):
            if _VERBOSE:
                print("Link capacity exceeded")
            continue

        # Now apply the new loads to the links.
        for link in path_links:
            link.add_protocol_load(protocol, load)

        # This IER is now valid, so set it to running.
        ier_value.set_is_running(True)
|
||||
|
||||
|
||||
def apply_node_pol(
    nodes: Dict[str, NodeUnion],
    node_pol: Dict[str, NodeStateInstructionGreen],
    step: int,
) -> None:
    """
    Applies node pattern of life.

    Args:
        nodes: The nodes within the environment
        node_pol: The node pattern of life to apply
        step: The step number.
    """
    if _VERBOSE:
        print("Applying Node PoL")

    for node_instruction in node_pol.values():
        start_step = node_instruction.get_start_step()
        stop_step = node_instruction.get_end_step()

        if step < start_step or step > stop_step:
            # PoL is not valid in this time step.
            continue

        node = nodes[node_instruction.get_node_id()]
        node_pol_type = node_instruction.get_node_pol_type()
        state = node_instruction.get_state()

        if node_pol_type == NodePOLType.OPERATING:
            # Change hardware state.
            node.hardware_state = state
        elif node_pol_type == NodePOLType.OS:
            # Change OS state. Don't allow PoL to fix something that is
            # compromised - only the Blue agent can do this.
            if isinstance(node, (ActiveNode, ServiceNode)):
                node.set_software_state_if_not_compromised(state)
        elif node_pol_type == NodePOLType.SERVICE:
            # Change a service state; again, never un-compromise anything.
            if isinstance(node, ServiceNode):
                node.set_service_state_if_not_compromised(node_instruction.get_service_name(), state)
        else:
            # Change the file system status.
            if isinstance(node, (ActiveNode, ServiceNode)):
                node.set_file_system_state_if_not_compromised(state)
|
||||
@@ -1,147 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
Used to represent an information flow from source to destination.
|
||||
"""
|
||||
|
||||
|
||||
class IER:
    """Information Exchange Requirement class.

    Represents a single information flow from a source node to a destination
    node, carrying a protocol load over a port for a window of steps.
    """

    def __init__(
        self,
        _id: str,
        _start_step: int,
        _end_step: int,
        _load: int,
        _protocol: str,
        _port: str,
        _source_node_id: str,
        _dest_node_id: str,
        _mission_criticality: int,
        _running: bool = False,
    ) -> None:
        """
        Initialise an Information Exchange Request.

        :param _id: The IER id.
        :param _start_step: The step when this IER should start.
        :param _end_step: The step when this IER should end.
        :param _load: The load this IER should put on a link (bps).
        :param _protocol: The protocol of this IER.
        :param _port: The port this IER runs on.
        :param _source_node_id: The source node ID.
        :param _dest_node_id: The destination node ID.
        :param _mission_criticality: Criticality of this IER to the mission
            (0 none, 5 mission critical).
        :param _running: Indicates whether the IER is currently running.
        """
        self.id: str = _id
        self.start_step: int = _start_step
        self.end_step: int = _end_step
        self.load: int = _load
        self.protocol: str = _protocol
        self.port: str = _port
        self.source_node_id: str = _source_node_id
        self.dest_node_id: str = _dest_node_id
        self.mission_criticality: int = _mission_criticality
        self.running: bool = _running

    def get_id(self) -> str:
        """Return the IER ID."""
        return self.id

    def get_start_step(self) -> int:
        """Return the step at which this IER becomes active."""
        return self.start_step

    def get_end_step(self) -> int:
        """Return the last step at which this IER is active."""
        return self.end_step

    def get_load(self) -> int:
        """Return the link load of this IER (bps)."""
        return self.load

    def get_protocol(self) -> str:
        """Return the protocol carried by this IER."""
        return self.protocol

    def get_port(self) -> str:
        """Return the port this IER runs on."""
        return self.port

    def get_source_node_id(self) -> str:
        """Return the ID of the source node."""
        return self.source_node_id

    def get_dest_node_id(self) -> str:
        """Return the ID of the destination node."""
        return self.dest_node_id

    def get_is_running(self) -> bool:
        """Return True when the IER is currently running."""
        return self.running

    def set_is_running(self, _value: bool) -> None:
        """Set the running state of the IER.

        Args:
            _value: running status
        """
        self.running = _value

    def get_mission_criticality(self) -> int:
        """Return the mission criticality (0 lowest to 5 highest); used in the reward function."""
        return self.mission_criticality
|
||||
@@ -1,353 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
from typing import Dict
|
||||
|
||||
from networkx import MultiGraph, shortest_path
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.common.custom_typing import NodeUnion
|
||||
from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState
|
||||
from primaite.links.link import Link
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
from primaite.pol.ier import IER
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
_VERBOSE: bool = False
|
||||
|
||||
|
||||
def apply_red_agent_iers(
    network: MultiGraph,
    nodes: Dict[str, NodeUnion],
    links: Dict[str, Link],
    iers: Dict[str, IER],
    acl: AccessControlList,
    step: int,
) -> None:
    """
    Applies IERs to the links (link POL) resulting from red agent attack.

    Args:
        network: The network modelled in the environment
        nodes: The nodes within the environment
        links: The links within the environment
        iers: The red agent IERs to apply to the links
        acl: The Access Control List
        step: The step number.
    """
    # Go through each IER and check the conditions for it being applied.
    # If everything is in place, apply the IER protocol load to the relevant links.
    for ier_value in iers.values():
        start_step = ier_value.get_start_step()
        stop_step = ier_value.get_end_step()
        protocol = ier_value.get_protocol()
        port = ier_value.get_port()
        load = ier_value.get_load()

        # Every IER starts each step as not running; it is only switched back
        # on if it passes all of the checks below.
        ier_value.set_is_running(False)

        if step < start_step or step > stop_step:
            # IER is not scheduled for this time step.
            continue

        # Get the source and destination node for this IER.
        source_node = nodes[ier_value.get_source_node_id()]
        dest_node = nodes[ier_value.get_dest_node_id()]

        # 1. Check the source node situation.
        if source_node.node_type == NodeType.SWITCH:
            # It's a switch - red traffic only needs it to be powered on.
            source_valid = source_node.hardware_state == HardwareState.ON
        elif source_node.node_type == NodeType.ACTUATOR:
            # Actuator behaviour is not implemented yet (TO DO); treat as valid.
            source_valid = True
        else:
            # It's not a switch or an actuator (so active node).
            # TODO: this occurs after ruling out the possibility that the node is a switch or an
            # actuator, but it could still be a passive/active node, therefore it won't have a
            # hardware_state. The logic here needs to change according to duck typing.
            source_valid = False
            if source_node.hardware_state == HardwareState.ON and source_node.has_service(protocol):
                # Red agent IERs can only be valid if the source service is in a compromised state.
                source_valid = source_node.get_service_state(protocol) == SoftwareState.COMPROMISED

        # 2. Check the dest node situation.
        if dest_node.node_type == NodeType.SWITCH:
            dest_valid = dest_node.hardware_state == HardwareState.ON
        elif dest_node.node_type == NodeType.ACTUATOR:
            dest_valid = True
        else:
            # We don't care what state the destination service is in; it just
            # has to exist on a powered-on node.
            dest_valid = dest_node.hardware_state == HardwareState.ON and dest_node.has_service(protocol)

        # 3. Check that the ACL doesn't block it.
        acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port)
        if _VERBOSE:
            if acl_block:
                print(
                    "ACL block on source: "
                    + source_node.ip_address
                    + ", dest: "
                    + dest_node.ip_address
                    + ", protocol: "
                    + protocol
                    + ", port: "
                    + port
                )
            else:
                print("No ACL block")

        # Both endpoints must be valid and there must be no ACL block.
        if not (source_valid and dest_valid and not acl_block):
            if _VERBOSE:
                print("Red IER was NOT allowed to run in step " + str(step))
                print("Source, Dest or ACL were not valid")
            continue

        if _VERBOSE:
            print("Source, Dest and ACL valid")

        # Get the shortest path (i.e. nodes) between source and destination.
        # We're assuming here that red agents can get past switches that are
        # patching, so only the hardware state matters for path validity.
        path_node_list = shortest_path(network, source_node, dest_node)
        if not all(node.hardware_state == HardwareState.ON for node in path_node_list):
            # One of the nodes on the path is not operational.
            if _VERBOSE:
                print("Path not valid - one or more nodes not operational")
            continue

        if _VERBOSE:
            print("Applying IER to link(s)")

        # Resolve the links along the path once; the list is reused for both
        # the capacity check and applying the load.
        path_links = []
        for node_a, node_b in zip(path_node_list, path_node_list[1:]):
            edge_dict = network.get_edge_data(node_a, node_b)
            path_links.append(links[edge_dict[0].get("id")])

        # The IER only runs if no link on the path would exceed its bandwidth.
        if any((link.get_current_load() + load) > link.get_bandwidth() for link in path_links):
            if _VERBOSE:
                print("Link capacity exceeded")
            continue

        # Now apply the new loads to the links.
        for link in path_links:
            link.add_protocol_load(protocol, load)

        # This IER is now valid, so set it to running.
        ier_value.set_is_running(True)
        if _VERBOSE:
            print("Red IER was allowed to run in step " + str(step))
|
||||
|
||||
|
||||
def apply_red_agent_node_pol(
    nodes: Dict[str, NodeUnion],
    iers: Dict[str, IER],
    node_pol: Dict[str, NodeStateInstructionRed],
    step: int,
) -> None:
    """
    Applies node pattern of life.

    Args:
        nodes: The nodes within the environment
        iers: The red agent IERs
        node_pol: The red agent node pattern of life to apply
        step: The step number.
    """
    if _VERBOSE:
        print("Applying Node Red Agent PoL")

    for node_instruction in node_pol.values():
        start_step = node_instruction.get_start_step()
        stop_step = node_instruction.get_end_step()

        if step < start_step or step > stop_step:
            # PoL is not valid in this time step.
            continue

        target_node: NodeUnion = nodes[node_instruction.get_target_node_id()]
        initiator = node_instruction.get_initiator()
        pol_type = node_instruction.get_pol_type()
        state = node_instruction.get_state()

        # The initiator may arrive as a raw string; normalise it to the
        # NodePOLInitiator enum before dispatching on it.
        if isinstance(initiator, str):
            initiator = NodePOLInitiator[initiator]

        # Base the action taken on the initiator type.
        passed_checks = False
        if initiator == NodePOLInitiator.DIRECT:
            # No conditions required, just apply the change.
            passed_checks = True
        elif initiator == NodePOLInitiator.IER:
            # Need to check there is a red IER incoming.
            passed_checks = is_red_ier_incoming(target_node, iers, pol_type)
        elif initiator == NodePOLInitiator.SERVICE:
            # Need to check the condition of a service on another node.
            source_node = nodes[node_instruction.get_source_node_id()]
            source_node_service_name = node_instruction.get_source_node_service()
            if source_node.has_service(source_node_service_name):
                # Only passes when the remote service is in the required state.
                passed_checks = (
                    source_node.get_service_state(source_node_service_name)
                    == SoftwareState[node_instruction.get_source_node_service_state()]
                )
        else:
            _LOGGER.warning("Node Red Agent PoL not allowed - misconfiguration")

        # Only apply the PoL if the checks have passed (based on the initiator type).
        if not passed_checks:
            _LOGGER.debug("Node Red Agent PoL not allowed - did not pass checks")
            continue

        # Apply the change.
        if pol_type == NodePOLType.OPERATING:
            # Change hardware state.
            target_node.hardware_state = state
        elif pol_type == NodePOLType.OS:
            # Change OS state.
            if isinstance(target_node, (ActiveNode, ServiceNode)):
                target_node.software_state = state
        elif pol_type == NodePOLType.SERVICE:
            # Change a service state.
            if isinstance(target_node, ServiceNode):
                target_node.set_service_state(node_instruction.get_service_name(), state)
        else:
            # Change the file system status.
            if isinstance(target_node, (ActiveNode, ServiceNode)):
                target_node.set_file_system_state(state)
|
||||
|
||||
|
||||
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
    """Checks if a RED IER is incoming to *node*.

    Scans ALL IERs and returns True when at least one running red IER is
    destined for this node and satisfies the requirements of the given
    Pattern-Of-Life type.

    BUG FIX: the previous implementation returned False from inside the loop
    as soon as the first IER failed to match, so only the first entry of the
    ``iers`` dict was ever considered. It also implicitly returned ``None``
    when ``iers`` was empty. Non-matching IERs are now skipped and the
    function always returns an explicit bool.

    :param node: Destination node of the IER
    :type node: NodeUnion
    :param iers: Directory of IERs
    :type iers: Dict[str,IER]
    :param node_pol_type: Type of Pattern-Of-Life
    :type node_pol_type: NodePOLType
    :return: Whether a RED IER is incoming.
    :rtype: bool
    """
    node_id = node.node_id

    for ier_value in iers.values():
        if not (ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id):
            # This IER is not running, or its destination is not this node -
            # keep scanning the remaining IERs.
            continue

        if node_pol_type in (NodePOLType.OPERATING, NodePOLType.OS, NodePOLType.FILE):
            # It's looking to change hardware state, file system or
            # SoftwareState, so any incoming running IER is sufficient.
            return True

        if node_pol_type == NodePOLType.SERVICE:
            # The matching service must be present on the node and running.
            ier_protocol = ier_value.get_protocol()
            if (
                isinstance(node, ServiceNode)
                and node.has_service(ier_protocol)
                and node.service_running(ier_protocol)
            ):
                return True
            # Service missing or not running for this IER - try the next one.
            continue

        # Shouldn't get here - instruction type is undefined, so it can never
        # match any IER.
        return False

    # No matching incoming IER was found (also covers an empty iers dict).
    return False
|
||||
@@ -1,228 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Tuple, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
|
||||
# from primaite.agents.rllib import RLlibAgent
|
||||
from primaite.agents.sb3 import SB3Agent
|
||||
from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
from primaite.utils.session_output_reader import all_transactions_dict, av_rewards_dict
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""
|
||||
The PrimaiteSession class.
|
||||
|
||||
Provides a single learning and evaluation entry point for all training and lay down configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path, legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
if self._training_config.agent_framework == AgentFramework.CUSTOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}")
|
||||
if self._training_config.agent_identifier == AgentIdentifier.HARDCODED:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
# Invalid AgentIdentifier ActionType combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
raise ValueError
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path,
|
||||
self.session_path,
|
||||
self.legacy_training_config,
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
# elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
# _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# # Ray RLlib Agent
|
||||
# self._agent_session = RLlibAgent(
|
||||
# self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
# )
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
raise ValueError
|
||||
|
||||
self.session_path: Path = self._agent_session.session_path
|
||||
self.timestamp_str: str = self._agent_session.timestamp_str
|
||||
self.learning_path: Path = self._agent_session.learning_path
|
||||
self.evaluation_path: Path = self._agent_session.evaluation_path
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.EVAL:
|
||||
self._agent_session.learn(**kwargs)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-framework specific key word args.
|
||||
"""
|
||||
if not self._training_config.session_type == SessionType.TRAIN:
|
||||
self._agent_session.evaluate(**kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the agent."""
|
||||
self._agent_session.close()
|
||||
|
||||
def learn_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the learn av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_av_reward_per_episode_dict(self) -> Dict[int, float]:
|
||||
"""Get the eval av reward per episode from file."""
|
||||
csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv"
|
||||
return av_rewards_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def learn_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the learn all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.learning_path / csv_file)
|
||||
|
||||
def eval_all_transactions_dict(self) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""Get the eval all transactions from file."""
|
||||
csv_file = f"all_transactions_{self.timestamp_str}.csv"
|
||||
return all_transactions_dict(self.evaluation_path / csv_file)
|
||||
|
||||
def metadata_file_as_dict(self) -> Dict[str, Any]:
|
||||
"""Read the session_metadata.json file and return as a dict."""
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
return json.load(file)
|
||||
@@ -1,14 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run() -> None:
|
||||
"""Perform the full clean-up."""
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import SimComponent
|
||||
from src.primaite.simulator.core import SimComponent
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Final, List, Literal, Tuple
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account, AccountType
|
||||
from src.primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
|
||||
from src.primaite.simulator.domain.account import Account, AccountType
|
||||
|
||||
|
||||
# placeholder while these objects don't yet exist
|
||||
|
||||
@@ -7,8 +7,8 @@ from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
from primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
|
||||
from src.primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
from src.primaite.simulator.file_system.file_type import FileType, get_file_type_from_extension
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.file_system.file import File
|
||||
from src.primaite.simulator.file_system.file_type import FileType
|
||||
from src.primaite.simulator.file_system.folder import Folder
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -41,19 +41,19 @@ def convert_size(size_bytes: int) -> str:
|
||||
class FileSystemItemHealthStatus(Enum):
|
||||
"""Status of the FileSystemItem."""
|
||||
|
||||
GOOD = 0
|
||||
GOOD = 1
|
||||
"""File/Folder is OK."""
|
||||
|
||||
COMPROMISED = 1
|
||||
COMPROMISED = 2
|
||||
"""File/Folder is quarantined."""
|
||||
|
||||
CORRUPT = 2
|
||||
CORRUPT = 3
|
||||
"""File/Folder is corrupted."""
|
||||
|
||||
RESTORING = 3
|
||||
RESTORING = 4
|
||||
"""File/Folder is in the process of being restored."""
|
||||
|
||||
REPAIRING = 3
|
||||
REPAIRING = 5
|
||||
"""File/Folder is in the process of being repaired."""
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ class FileSystemItemABC(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["name"] = self.name
|
||||
state["status"] = self.health_status.name
|
||||
state["visible_status"] = self.visible_health_status.name
|
||||
state["health_status"] = self.health_status.value
|
||||
state["visible_status"] = self.visible_health_status.value
|
||||
state["previous_hash"] = self.previous_hash
|
||||
state["revealed_to_red"] = self.revealed_to_red
|
||||
return state
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
from src.primaite.simulator.core import RequestManager, RequestType
|
||||
from src.primaite.simulator.file_system.file import File
|
||||
from src.primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,12 +6,12 @@ from networkx import MultiGraph
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.network.hardware.base import Link, NIC, Node, SwitchPort
|
||||
from src.primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from src.primaite.simulator.network.hardware.nodes.router import Router
|
||||
from src.primaite.simulator.network.hardware.nodes.server import Server
|
||||
from src.primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -160,8 +160,8 @@ class Network(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"nodes": {i for i, node in self._node_id_map.items()},
|
||||
"links": {i: link.describe_state() for i, link in self._link_id_map.items()},
|
||||
"nodes": {uuid: node.describe_state() for uuid, node in self.nodes.items()},
|
||||
"links": {uuid: link.describe_state() for uuid, link in self.links.items()},
|
||||
}
|
||||
)
|
||||
return state
|
||||
@@ -218,7 +218,9 @@ class Network(SimComponent):
|
||||
_LOGGER.info(f"Removed node {node.uuid} from network {self.uuid}")
|
||||
self._node_request_manager.remove_request(name=node.uuid)
|
||||
|
||||
def connect(self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs) -> None:
|
||||
def connect(
|
||||
self, endpoint_a: Union[NIC, SwitchPort], endpoint_b: Union[NIC, SwitchPort], **kwargs
|
||||
) -> Optional[Link]:
|
||||
"""
|
||||
Connect two endpoints on the network by creating a link between their NICs/SwitchPorts.
|
||||
|
||||
@@ -245,6 +247,7 @@ class Network(SimComponent):
|
||||
self._nx_graph.add_edge(endpoint_a.parent.hostname, endpoint_b.parent.hostname)
|
||||
link.parent = self
|
||||
_LOGGER.debug(f"Added link {link.uuid} to connect {endpoint_a} and {endpoint_b}")
|
||||
return link
|
||||
|
||||
def remove_link(self, link: Link) -> None:
|
||||
"""Disconnect a link from the network.
|
||||
|
||||
@@ -10,22 +10,22 @@ from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.account import Account
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.processes.process import Process
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from src.primaite.exceptions import NetworkError
|
||||
from src.primaite.simulator import SIM_OUTPUT
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.domain.account import Account
|
||||
from src.primaite.simulator.file_system.file_system import FileSystem
|
||||
from src.primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
|
||||
from src.primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from src.primaite.simulator.system.applications.application import Application
|
||||
from src.primaite.simulator.system.core.packet_capture import PacketCapture
|
||||
from src.primaite.simulator.system.core.session_manager import SessionManager
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.system.processes.process import Process
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -859,14 +859,14 @@ class ICMP:
|
||||
class NodeOperatingState(Enum):
|
||||
"""Enumeration of Node Operating States."""
|
||||
|
||||
OFF = 0
|
||||
"The node is powered off."
|
||||
ON = 1
|
||||
"The node is powered on."
|
||||
SHUTTING_DOWN = 2
|
||||
"The node is in the process of shutting down."
|
||||
OFF = 2
|
||||
"The node is powered off."
|
||||
BOOTING = 3
|
||||
"The node is in the process of booting up."
|
||||
SHUTTING_DOWN = 4
|
||||
"The node is in the process of shutting down."
|
||||
|
||||
|
||||
class Node(SimComponent):
|
||||
@@ -962,6 +962,7 @@ class Node(SimComponent):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
kwargs["software_manager"] = SoftwareManager(
|
||||
parent_node=self,
|
||||
sys_log=kwargs.get("sys_log"),
|
||||
session_manager=kwargs.get("session_manager"),
|
||||
file_system=kwargs.get("file_system"),
|
||||
@@ -1369,7 +1370,8 @@ class Node(SimComponent):
|
||||
self._service_request_manager.add_request(service.uuid, RequestType(func=service._request_manager))
|
||||
|
||||
def uninstall_service(self, service: Service) -> None:
|
||||
"""Uninstall and completely remove service from this node.
|
||||
"""
|
||||
Uninstall and completely remove service from this node.
|
||||
|
||||
:param service: Service object that is currently associated with this node.
|
||||
:type service: Service
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from primaite.simulator.network.hardware.base import NIC, Node
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from src.primaite.simulator.network.hardware.base import NIC, Node
|
||||
from src.primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from src.primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from src.primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
|
||||
|
||||
class Computer(Node):
|
||||
|
||||
@@ -7,12 +7,12 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.network.hardware.base import ARPCache, ICMP, NIC, Node
|
||||
from src.primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
class ACLAction(Enum):
|
||||
@@ -58,7 +58,14 @@ class ACLRule(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
class AccessControlList(SimComponent):
|
||||
@@ -104,11 +111,11 @@ class AccessControlList(SimComponent):
|
||||
RequestType(
|
||||
func=lambda request, context: self.add_rule(
|
||||
ACLAction[request[0]],
|
||||
IPProtocol[request[1]],
|
||||
IPv4Address[request[2]],
|
||||
Port[request[3]],
|
||||
IPv4Address[request[4]],
|
||||
Port[request[5]],
|
||||
None if request[1] == "ALL" else IPProtocol[request[1]],
|
||||
IPv4Address(request[2]),
|
||||
None if request[3] == "ALL" else Port[request[3]],
|
||||
IPv4Address(request[4]),
|
||||
None if request[5] == "ALL" else Port[request[5]],
|
||||
int(request[6]),
|
||||
)
|
||||
),
|
||||
@@ -123,7 +130,12 @@ class AccessControlList(SimComponent):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["implicit_action"] = self.implicit_action.value
|
||||
state["implicit_rule"] = self.implicit_rule.describe_state()
|
||||
state["max_acl_rules"] = self.max_acl_rules
|
||||
state["acl"] = {i: r.describe_state() if isinstance(r, ACLRule) else None for i, r in enumerate(self._acl)}
|
||||
return state
|
||||
|
||||
@property
|
||||
def acl(self) -> List[Optional[ACLRule]]:
|
||||
@@ -648,7 +660,10 @@ class Router(Node):
|
||||
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
pass
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from src.primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
|
||||
|
||||
class Server(Computer):
|
||||
|
||||
@@ -3,10 +3,9 @@ from typing import Dict
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.exceptions import NetworkError
|
||||
from primaite.links.link import Link
|
||||
from primaite.simulator.network.hardware.base import Node, SwitchPort
|
||||
from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from src.primaite.exceptions import NetworkError
|
||||
from src.primaite.simulator.network.hardware.base import Link, Node, SwitchPort
|
||||
from src.primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -55,12 +54,11 @@ class Switch(Node):
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
"""
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
"num_ports": self.num_ports, # redundant?
|
||||
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
|
||||
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
|
||||
}
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"] = self.num_ports # redundant?
|
||||
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
"""
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
from ipaddress import IPv4Address
|
||||
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from src.primaite.simulator.network.container import Network
|
||||
from src.primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from src.primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from src.primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
from src.primaite.simulator.network.hardware.nodes.server import Server
|
||||
from src.primaite.simulator.network.hardware.nodes.switch import Switch
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from src.primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from src.primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from src.primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from src.primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
|
||||
from src.primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
|
||||
|
||||
def client_server_routed() -> Network:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from src.primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class ARPEntry(BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from src.primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class DNSRequest(BaseModel):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from src.primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class FTPCommand(Enum):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from src.primaite.simulator.network.protocols.packet import DataPacket
|
||||
|
||||
|
||||
class HttpRequestMethod(Enum):
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.arp import ARPPacket
|
||||
from primaite.simulator.network.protocols.packet import DataPacket
|
||||
from primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
|
||||
from primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
|
||||
from primaite.simulator.network.utils import convert_bytes_to_megabits
|
||||
from src.primaite.simulator.network.protocols.arp import ARPPacket
|
||||
from src.primaite.simulator.network.protocols.packet import DataPacket
|
||||
from src.primaite.simulator.network.transmission.network_layer import ICMPPacket, IPPacket, IPProtocol
|
||||
from src.primaite.simulator.network.transmission.primaite_layer import PrimaiteHeader
|
||||
from src.primaite.simulator.network.transmission.transport_layer import TCPHeader, UDPHeader
|
||||
from src.primaite.simulator.network.utils import convert_bytes_to_megabits
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Dict
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.domain.controller import DomainController
|
||||
from primaite.simulator.network.container import Network
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.domain.controller import DomainController
|
||||
from src.primaite.simulator.network.container import Network
|
||||
|
||||
|
||||
class Simulation(SimComponent):
|
||||
@@ -27,6 +27,7 @@ class Simulation(SimComponent):
|
||||
rm.add_request("network", RequestType(func=self.network._request_manager))
|
||||
# pass through domain requests to the domain object
|
||||
rm.add_request("domain", RequestType(func=self.domain._request_manager))
|
||||
rm.add_request("do_nothing", RequestType(func=lambda request, context: ()))
|
||||
return rm
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
from src.primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
|
||||
class ApplicationOperatingState(Enum):
|
||||
@@ -51,7 +51,7 @@ class Application(IOSoftware):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"opearting_state": self.operating_state.name,
|
||||
"opearting_state": self.operating_state.value,
|
||||
"execution_control_status": self.execution_control_status,
|
||||
"num_executions": self.num_executions,
|
||||
"groups": list(self.groups),
|
||||
|
||||
@@ -4,10 +4,10 @@ from uuid import uuid4
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
|
||||
@@ -2,11 +2,11 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from src.primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.applications.application import Application
|
||||
from src.primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
|
||||
|
||||
class WebBrowser(Application):
|
||||
@@ -38,7 +38,8 @@ class WebBrowser(Application):
|
||||
|
||||
:return: A dictionary capturing the current state of the WebBrowser and its child objects.
|
||||
"""
|
||||
return super().describe_state()
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from src.primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class _JSONFilter(logging.Filter):
|
||||
|
||||
@@ -5,15 +5,15 @@ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.core import SimComponent
|
||||
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
from src.primaite.simulator.core import SimComponent
|
||||
from src.primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPPacket, IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port, TCPHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.network.hardware.base import ARPCache
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.network.hardware.base import ARPCache
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
class Session(SimComponent):
|
||||
|
||||
@@ -3,17 +3,18 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import IOSoftware
|
||||
from src.primaite.simulator.file_system.file_system import FileSystem
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.applications.application import Application, ApplicationOperatingState
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from src.primaite.simulator.system.software import IOSoftware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.simulator.system.core.session_manager import SessionManager
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.system.core.session_manager import SessionManager
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.network.hardware.base import Node
|
||||
|
||||
from typing import Type, TypeVar
|
||||
|
||||
@@ -25,6 +26,7 @@ class SoftwareManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_node: "Node",
|
||||
session_manager: "SessionManager",
|
||||
sys_log: SysLog,
|
||||
file_system: FileSystem,
|
||||
@@ -35,6 +37,7 @@ class SoftwareManager:
|
||||
|
||||
:param session_manager: The session manager handling network communications.
|
||||
"""
|
||||
self.node = parent_node
|
||||
self.session_manager = session_manager
|
||||
self.software: Dict[str, Union[Service, Application]] = {}
|
||||
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
|
||||
@@ -62,6 +65,8 @@ class SoftwareManager:
|
||||
|
||||
:param software_class: The software class.
|
||||
"""
|
||||
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
|
||||
# separation of concerns.
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
@@ -77,6 +82,12 @@ class SoftwareManager:
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
# add the software to the node's registry after it has been fully initialized
|
||||
if isinstance(software, Service):
|
||||
self.node.install_service(software)
|
||||
elif isinstance(software, Application):
|
||||
self.node.install_application(software)
|
||||
|
||||
def uninstall(self, software_name: str):
|
||||
"""
|
||||
Uninstall an Application or Service.
|
||||
@@ -85,6 +96,10 @@ class SoftwareManager:
|
||||
"""
|
||||
if software_name in self.software:
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.uninstall_application(software)
|
||||
elif isinstance(software, Service):
|
||||
self.node.uninstall_service(software)
|
||||
del software
|
||||
self.sys_log.info(f"Deleted {software_name}")
|
||||
return
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from src.primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class _NotJSONFilter(logging.Filter):
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from primaite.simulator.system.software import Software
|
||||
from src.primaite.simulator.system.software import Software
|
||||
|
||||
|
||||
class ProcessOperatingState(Enum):
|
||||
|
||||
@@ -6,13 +6,13 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
from src.primaite.simulator.file_system.file_system import File
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from src.primaite.simulator.system.services.service import Service, ServiceOperatingState
|
||||
from src.primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
class DatabaseService(Service):
|
||||
|
||||
@@ -2,11 +2,11 @@ from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from src.primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from src.primaite.simulator.network.protocols.dns import DNSPacket
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from src.primaite.simulator.file_system.file_system import File
|
||||
from src.primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
from src.primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from src.primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
class FTPClient(FTPServiceABC):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from src.primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
|
||||
from src.primaite.simulator.system.services.service import ServiceOperatingState
|
||||
|
||||
|
||||
class FTPServer(FTPServiceABC):
|
||||
|
||||
@@ -3,10 +3,10 @@ from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from src.primaite.simulator.file_system.file_system import File
|
||||
from src.primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class FTPServiceABC(Service, ABC):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Optional
|
||||
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from src.primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
|
||||
class DataManipulationBot(DatabaseClient):
|
||||
|
||||
@@ -2,8 +2,8 @@ from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
from src.primaite.simulator.core import RequestManager, RequestType
|
||||
from src.primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -15,14 +15,14 @@ class ServiceOperatingState(Enum):
|
||||
"The service is currently running."
|
||||
STOPPED = 2
|
||||
"The service is not running."
|
||||
INSTALLING = 3
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 4
|
||||
"The service is in the process of restarting."
|
||||
PAUSED = 5
|
||||
PAUSED = 3
|
||||
"The service is temporarily paused."
|
||||
DISABLED = 6
|
||||
DISABLED = 4
|
||||
"The service is disabled and cannot be started."
|
||||
INSTALLING = 5
|
||||
"The service is being installed or updated."
|
||||
RESTARTING = 6
|
||||
"The service is in the process of restarting."
|
||||
|
||||
|
||||
class Service(IOSoftware):
|
||||
@@ -68,7 +68,7 @@ class Service(IOSoftware):
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["operating_state"] = self.operating_state.name
|
||||
state["operating_state"] = self.operating_state.value
|
||||
state["health_state_actual"] = self.health_state_actual
|
||||
state["health_state_visible"] = self.health_state_visible
|
||||
return state
|
||||
|
||||
@@ -1,22 +1,39 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from primaite.simulator.network.protocols.http import (
|
||||
from src.primaite.simulator.network.protocols.http import (
|
||||
HttpRequestMethod,
|
||||
HttpRequestPacket,
|
||||
HttpResponsePacket,
|
||||
HttpStatusCode,
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from src.primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from src.primaite.simulator.system.services.service import Service
|
||||
|
||||
|
||||
class WebServer(Service):
|
||||
"""Class used to represent a Web Server Service in simulation."""
|
||||
|
||||
last_response_status_code: Optional[HttpStatusCode] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
|
||||
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
|
||||
|
||||
:return: Current state of this object and child objects.
|
||||
:rtype: Dict
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["last_response_status_code"] = (
|
||||
self.last_response_status_code.value if self.last_response_status_code else None
|
||||
)
|
||||
return state
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "WebServer"
|
||||
kwargs["protocol"] = IPProtocol.TCP
|
||||
@@ -66,6 +83,7 @@ class WebServer(Service):
|
||||
self.send(payload=response, session_id=session_id)
|
||||
|
||||
# return true if response is OK
|
||||
self.last_response_status_code = response.status_code
|
||||
return response.status_code == HttpStatusCode.OK
|
||||
|
||||
def _handle_get_request(self, payload: HttpRequestPacket) -> HttpResponsePacket:
|
||||
|
||||
@@ -3,11 +3,11 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.core.session_manager import Session
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from src.primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from src.primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from src.primaite.simulator.network.transmission.transport_layer import Port
|
||||
from src.primaite.simulator.system.core.session_manager import Session
|
||||
from src.primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
class SoftwareType(Enum):
|
||||
@@ -121,9 +121,9 @@ class Software(SimComponent):
|
||||
state = super().describe_state()
|
||||
state.update(
|
||||
{
|
||||
"health_state": self.health_state_actual.name,
|
||||
"health_state_red_view": self.health_state_visible.name,
|
||||
"criticality": self.criticality.name,
|
||||
"health_state": self.health_state_actual.value,
|
||||
"health_state_red_view": self.health_state_visible.value,
|
||||
"criticality": self.criticality.value,
|
||||
"patching_count": self.patching_count,
|
||||
"scanning_count": self.scanning_count,
|
||||
"revealed_to_red": self.revealed_to_red,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user