Merge remote-tracking branch 'origin/dev' into bugfix/1587-hardcoded-agent

This commit is contained in:
Marek Wolan
2023-07-09 18:07:30 +01:00
22 changed files with 879 additions and 505 deletions

4
.gitignore vendored
View File

@@ -138,4 +138,8 @@ dmypy.json
# Cython debug symbols
cython_debug/
# IDE
.idea/
# outputs
src/primaite/outputs/

View File

@@ -1 +1,64 @@
# PrimAITE
## Getting Started with PrimAITE
### Pre-Requisites
In order to get **PrimAITE** installed, you will need to have the following installed:
- `python3.8+`
- `python3-pip`
- `virtualenv`
**PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS.
### Installation from source
#### 1. Navigate to the PrimAITE folder and create a new python virtual environment (venv)
```unix
python3 -m venv <name_of_venv>
```
#### 2. Activate the venv
##### Unix
```bash
source <name_of_venv>/bin/activate
```
##### Windows
```powershell
.\<name_of_venv>\Scripts\activate
```
#### 3. Install `primaite` into the venv along with all of it's dependencies
```bash
python3 -m pip install -e .
```
### Development Installation
To install the development dependencies, postfix the command in step 3 above with the `[dev]` extra. Example:
```bash
python3 -m pip install -e .[dev]
```
## Building documentation
The PrimAITE documentation can be built with the following commands:
##### Unix
```bash
cd docs
make html
```
##### Windows
```powershell
cd docs
.\make.bat html
```
This will build the documentation as a collection of HTML files which uses the Read The Docs sphinx theme. Other build
options are available but may require additional dependencies such as LaTeX and PDF. Please refer to the Sphinx documentation
for your specific output requirements.

View File

@@ -82,203 +82,203 @@ The environment config file consists of the following attributes:
Rewards are calculated based on the difference between the current state and reference state (the 'should be' state) of the environment.
* **Generic [all_ok]** [int]
* **Generic [all_ok]** [float]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [int]
* **Node Hardware State [off_should_be_on]** [float]
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [int]
* **Node Hardware State [off_should_be_resetting]** [float]
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [int]
* **Node Hardware State [on_should_be_off]** [float]
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [int]
* **Node Hardware State [on_should_be_resetting]** [float]
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [int]
* **Node Hardware State [resetting_should_be_on]** [float]
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [int]
* **Node Hardware State [resetting_should_be_off]** [float]
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [int]
* **Node Hardware State [resetting]** [float]
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [int]
* **Node Operating System or Service State [good_should_be_patching]** [float]
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [int]
* **Node Operating System or Service State [good_should_be_compromised]** [float]
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [good_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [int]
* **Node Operating System or Service State [patching_should_be_good]** [float]
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [int]
* **Node Operating System or Service State [patching_should_be_compromised]** [float]
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [int]
* **Node Operating System or Service State [patching]** [float]
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [int]
* **Node Operating System or Service State [compromised_should_be_good]** [float]
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [int]
* **Node Operating System or Service State [compromised_should_be_patching]** [float]
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [float]
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [int]
* **Node Operating System or Service State [compromised]** [float]
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_good]** [float]
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [float]
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [float]
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [int]
* **Node Operating System or Service State [overwhelmed]** [float]
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [int]
* **Node File System State [good_should_be_repairing]** [float]
The score to give when the state should be repairing, but is good
* **Node File System State [good_should_be_restoring]** [int]
* **Node File System State [good_should_be_restoring]** [float]
The score to give when the state should be restoring, but is good
* **Node File System State [good_should_be_corrupt]** [int]
* **Node File System State [good_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is good
* **Node File System State [good_should_be_destroyed]** [int]
* **Node File System State [good_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is good
* **Node File System State [repairing_should_be_good]** [int]
* **Node File System State [repairing_should_be_good]** [float]
The score to give when the state should be good, but is repairing
* **Node File System State [repairing_should_be_restoring]** [int]
* **Node File System State [repairing_should_be_restoring]** [float]
The score to give when the state should be restoring, but is repairing
* **Node File System State [repairing_should_be_corrupt]** [int]
* **Node File System State [repairing_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is repairing
* **Node File System State [repairing_should_be_destroyed]** [int]
* **Node File System State [repairing_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is repairing
* **Node File System State [repairing]** [int]
* **Node File System State [repairing]** [float]
The score to give when the state is repairing
* **Node File System State [restoring_should_be_good]** [int]
* **Node File System State [restoring_should_be_good]** [float]
The score to give when the state should be good, but is restoring
* **Node File System State [restoring_should_be_repairing]** [int]
* **Node File System State [restoring_should_be_repairing]** [float]
The score to give when the state should be repairing, but is restoring
* **Node File System State [restoring_should_be_corrupt]** [int]
* **Node File System State [restoring_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is restoring
* **Node File System State [restoring_should_be_destroyed]** [int]
* **Node File System State [restoring_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is restoring
* **Node File System State [restoring]** [int]
* **Node File System State [restoring]** [float]
The score to give when the state is restoring
* **Node File System State [corrupt_should_be_good]** [int]
* **Node File System State [corrupt_should_be_good]** [float]
The score to give when the state should be good, but is corrupt
* **Node File System State [corrupt_should_be_repairing]** [int]
* **Node File System State [corrupt_should_be_repairing]** [float]
The score to give when the state should be repairing, but is corrupt
* **Node File System State [corrupt_should_be_restoring]** [int]
* **Node File System State [corrupt_should_be_restoring]** [float]
The score to give when the state should be restoring, but is corrupt
* **Node File System State [corrupt_should_be_destroyed]** [int]
* **Node File System State [corrupt_should_be_destroyed]** [float]
The score to give when the state should be destroyed, but is corrupt
* **Node File System State [corrupt]** [int]
* **Node File System State [corrupt]** [float]
The score to give when the state is corrupt
* **Node File System State [destroyed_should_be_good]** [int]
* **Node File System State [destroyed_should_be_good]** [float]
The score to give when the state should be good, but is destroyed
* **Node File System State [destroyed_should_be_repairing]** [int]
* **Node File System State [destroyed_should_be_repairing]** [float]
The score to give when the state should be repairing, but is destroyed
* **Node File System State [destroyed_should_be_restoring]** [int]
* **Node File System State [destroyed_should_be_restoring]** [float]
The score to give when the state should be restoring, but is destroyed
* **Node File System State [destroyed_should_be_corrupt]** [int]
* **Node File System State [destroyed_should_be_corrupt]** [float]
The score to give when the state should be corrupt, but is destroyed
* **Node File System State [destroyed]** [int]
* **Node File System State [destroyed]** [float]
The score to give when the state is destroyed
* **Node File System State [scanning]** [int]
* **Node File System State [scanning]** [float]
The score to give when the state is scanning
* **IER Status [red_ier_running]** [int]
* **IER Status [red_ier_running]** [float]
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [int]
* **IER Status [green_ier_blocked]** [float]
The score to give when a green agent IER is prevented from running
@@ -308,6 +308,14 @@ Rewards are calculated based on the difference between the current state and ref
The number of steps to take when scanning the file system
* **deterministic** [bool]
Set to true if the agent evaluation should be deterministic. Default is ``False``
* **seed** [int]
Seed used in the randomisation in agent training. Default is ``None``
The Lay Down Config
*******************

View File

@@ -10,19 +10,19 @@ class AccessControlList:
def __init__(self):
"""Init."""
self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules
self.acl: Dict[str, ACLRule] = {} # A dictionary of ACL Rules
def check_address_match(self, _rule, _source_ip_address, _dest_ip_address):
"""
Checks for IP address matches.
def check_address_match(self, _rule: ACLRule, _source_ip_address: str, _dest_ip_address: str) -> bool:
"""Checks for IP address matches.
Args:
_rule: The rule being checked
_source_ip_address: the source IP address to compare
_dest_ip_address: the destination IP address to compare
Returns:
True if match; False otherwise.
:param _rule: The rule object to check
:type _rule: ACLRule
:param _source_ip_address: Source IP address to compare
:type _source_ip_address: str
:param _dest_ip_address: Destination IP address to compare
:type _dest_ip_address: str
:return: True if there is a match, otherwise False.
:rtype: bool
"""
if (
(_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address)
@@ -34,7 +34,7 @@ class AccessControlList:
else:
return False
def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port):
def is_blocked(self, _source_ip_address: str, _dest_ip_address: str, _protocol: str, _port: str) -> bool:
"""
Checks for rules that block a protocol / port.

View File

@@ -257,10 +257,19 @@ class AgentSessionABC(ABC):
raise FileNotFoundError(msg)
pass
@property
def _saved_agent_path(self) -> Path:
file_name = (
f"{self._training_config.agent_framework}_"
f"{self._training_config.agent_identifier}_"
f"{self.timestamp_str}.zip"
)
return self.learning_path / file_name
@abstractmethod
def save(self):
"""Save the agent."""
self._agent.save(self.session_path)
pass
@abstractmethod
def export(self):

View File

@@ -1,5 +1,9 @@
from typing import Any, Dict, List, Union
import numpy as np
from primaite.acl.access_control_list import AccessControlList
from primaite.acl.acl_rule import ACLRule
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import (
get_new_action,
@@ -7,13 +11,17 @@ from primaite.agents.utils import (
transform_action_acl_enum,
transform_change_obs_readable,
)
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import HardCodedAgentView
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
from primaite.pol.ier import IER
class HardCodedACLAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic ACL agent."""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC:
# Basic view action using only the current observation
return self._calculate_action_basic_view(obs)
@@ -22,12 +30,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
# history and reward feedback
return self._calculate_action_full_view(obs)
def get_blocked_green_iers(self, green_iers, acl, nodes):
"""
Get blocked green IERs.
def get_blocked_green_iers(
self, green_iers: Dict[str, IER], acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[Any, Any]:
"""Get blocked green IERs.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param green_iers: Green IERs to check for being
:type green_iers: Dict[str, IER]
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: Same as `green_iers` input dict, but filtered to only contain the blocked ones.
:rtype: Dict[str, IER]
"""
blocked_green_iers = {}
@@ -45,12 +60,17 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_green_iers
def get_matching_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get matching ACL rules for an IER.
def get_matching_acl_rules_for_ier(self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]):
"""Get list of ACL rules which are relevant to an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
source_node_id = ier.get_source_node_id()
source_node_address = nodes[source_node_id].ip_address
@@ -58,11 +78,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
dest_node_address = nodes[dest_node_id].ip_address
protocol = ier.get_protocol() # e.g. 'TCP'
port = ier.get_port()
matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port)
return matching_rules
def get_blocking_acl_rules_for_ier(self, ier, acl, nodes):
def get_blocking_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
"""
Get blocking ACL rules for an IER.
@@ -70,8 +91,14 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
Can return empty dict but IER can still be blocked by default
(No ALLOW rule, therefore blocked).
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
@@ -82,12 +109,19 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return blocked_rules
def get_allow_acl_rules_for_ier(self, ier, acl, nodes):
"""
Get all allowing ACL rules for an IER.
def get_allow_acl_rules_for_ier(
self, ier: IER, acl: AccessControlList, nodes: Dict[str, NodeUnion]
) -> Dict[str, Any]:
"""Get all allowing ACL rules for an IER.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ier: Information Exchange Request to query against the ACL list
:type ier: IER
:param acl: Firewall rules
:type acl: AccessControlList
:param nodes: Nodes in the network
:type nodes: Dict[str,NodeUnion]
:return: _description_
:rtype: _type_
"""
matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes)
@@ -100,19 +134,32 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_matching_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get matching ACL rules.
source_node_id: str,
dest_node_id: str,
protocol: str,
port: str,
acl: AccessControlList,
nodes: Dict[str, Union[ServiceNode, ActiveNode]],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""Filter ACL rules to only those which are relevant to the specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node
:type source_node_id: str
:param dest_node_id: Destination nodes
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: str
:param port: Network port
:type port: str
:param acl: Access Control list which will be filtered
:type acl: AccessControlList
:param nodes: The environment's node directory.
:type nodes: Dict[str, Union[ServiceNode, ActiveNode]]
:param services_list: List of services registered for the environment.
:type services_list: List[str]
:return: Filtered version of 'acl'
:rtype: Dict[str, ACLRule]
"""
if source_node_id != "ANY":
source_node_address = nodes[str(source_node_id)].ip_address
@@ -132,19 +179,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_allow_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the ALLOW ACL rules.
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""List ALLOW rules relating to specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
@@ -165,19 +226,33 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
def get_deny_acl_rules(
self,
source_node_id,
dest_node_id,
protocol,
port,
acl,
nodes,
services_list,
):
"""
Get the DENY ACL rules.
source_node_id: int,
dest_node_id: str,
protocol: int,
port: str,
acl: AccessControlList,
nodes: Dict[str, NodeUnion],
services_list: List[str],
) -> Dict[str, ACLRule]:
"""List DENY rules relating to specified nodes.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param source_node_id: Source node id
:type source_node_id: int
:param dest_node_id: Destination node
:type dest_node_id: str
:param protocol: Network protocol
:type protocol: int
:param port: Port
:type port: str
:param acl: Firewall ruleset which is applied to the network
:type acl: AccessControlList
:param nodes: The simulation's node store
:type nodes: Dict[str, NodeUnion]
:param services_list: Services list
:type services_list: List[str]
:return: Filtered ACL Rule directory which includes only those rules which affect the specified source and
desination nodes
:rtype: Dict[str, ACLRule]
"""
matching_rules = self.get_matching_acl_rules(
source_node_id,
@@ -196,7 +271,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
return allowed_rules
def _calculate_action_full_view(self, obs):
def _calculate_action_full_view(self, obs: np.ndarray) -> int:
"""
Calculate a good acl-based action for the blue agent to take.
@@ -224,8 +299,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing
an overwhelmed state, so we don't do this.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
# obs = convert_to_old_obs(obs)
r_obs = transform_change_obs_readable(obs)
@@ -361,7 +438,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
action = get_new_action(action, self._env.action_dict)
return action
def _calculate_action_basic_view(self, obs):
def _calculate_action_basic_view(self, obs: np.ndarray) -> int:
"""Calculate a good acl-based action for the blue agent to take.
Uses ONLY information from the current observation with NO knowledge
@@ -379,8 +456,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC):
Currently, a deny rule does not overwrite an allow rule. The allow
rules must be deleted.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)

View File

@@ -1,3 +1,5 @@
import numpy as np
from primaite.agents.agent import HardCodedAgentSessionABC
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
@@ -5,12 +7,14 @@ from primaite.agents.utils import get_new_action, transform_action_node_enum, tr
class HardCodedNodeAgent(HardCodedAgentSessionABC):
"""An Agent Session class that implements a deterministic Node agent."""
def _calculate_action(self, obs):
def _calculate_action(self, obs: np.ndarray) -> int:
"""
Calculate a good node-based action for the blue agent to take.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: current observation from the gym environment
:type obs: np.ndarray
:return: Optimal action to take in the environment (chosen from the discrete action space)
:rtype: int
"""
action_dict = self._env.action_dict
r_obs = transform_change_obs_readable(obs)

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import shutil
from datetime import datetime
from pathlib import Path
from typing import Union
from uuid import uuid4
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.a2c import A2CConfig
@@ -106,6 +108,7 @@ class RLlibAgent(AgentSessionABC):
timestamp_str=self.timestamp_str,
),
)
self._agent_config.seed = self._training_config.seed
self._agent_config.training(train_batch_size=self._training_config.num_steps)
self._agent_config.framework(framework="tf")
@@ -120,9 +123,11 @@ class RLlibAgent(AgentSessionABC):
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
self._agent.save(str(self.checkpoints_path))
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
self._agent.save(str(self.checkpoints_path))
def learn(
self,
@@ -140,9 +145,14 @@ class RLlibAgent(AgentSessionABC):
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self.save()
self._agent.stop()
super().learn()
# save agent
self.save()
def evaluate(
self,
**kwargs,
@@ -162,9 +172,25 @@ class RLlibAgent(AgentSessionABC):
"""Load an agent from file."""
raise NotImplementedError
def save(self):
def save(self, overwrite_existing: bool = True):
"""Save the agent."""
raise NotImplementedError
# Make temp dir to save in isolation
temp_dir = self.learning_path / str(uuid4())
temp_dir.mkdir()
# Save the agent to the temp dir
self._agent.save(str(temp_dir))
# Capture the saved Rllib checkpoint inside the temp directory
for file in temp_dir.iterdir():
checkpoint_dir = file
break
# Zip the folder
shutil.make_archive(str(self._saved_agent_path).replace(".zip", ""), "zip", checkpoint_dir) # noqa
# Drop the temp directory
shutil.rmtree(temp_dir)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -59,16 +59,19 @@ class SB3Agent(AgentSessionABC):
verbose=self.sb3_output_verbose_level,
n_steps=self._training_config.num_steps,
tensorboard_log=str(self._tensorboard_log_path),
seed=self._training_config.seed,
)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._env.episode_count
if checkpoint_n > 0 and episode_count > 0:
if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes):
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
save_checkpoint = False
if checkpoint_n:
save_checkpoint = episode_count % checkpoint_n == 0
if episode_count and save_checkpoint:
checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip"
self._agent.save(checkpoint_path)
_LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}")
def _get_latest_checkpoint(self):
pass
@@ -90,25 +93,27 @@ class SB3Agent(AgentSessionABC):
self._agent.learn(total_timesteps=time_steps)
self._save_checkpoint()
self._env.reset()
self.save()
self._env.close()
super().learn()
# save agent
self.save()
def evaluate(
self,
deterministic: bool = True,
**kwargs,
):
"""
Evaluate the agent.
:param deterministic: Whether the evaluation is deterministic.
:param kwargs: Any agent-specific key-word args to be passed.
"""
time_steps = self._training_config.num_steps
episodes = self._training_config.num_episodes
self._env.set_as_eval()
self.is_eval = True
if deterministic:
if self._training_config.deterministic:
deterministic_str = "deterministic"
else:
deterministic_str = "non-deterministic"
@@ -119,7 +124,7 @@ class SB3Agent(AgentSessionABC):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(obs, deterministic=deterministic)
action, _states = self._agent.predict(obs, deterministic=self._training_config.deterministic)
if isinstance(action, np.ndarray):
action = np.int64(action)
obs, rewards, done, info = self._env.step(action)
@@ -134,7 +139,7 @@ class SB3Agent(AgentSessionABC):
def save(self):
"""Save the agent."""
raise NotImplementedError
self._agent.save(self._saved_agent_path)
def export(self):
"""Export the agent to transportable file format."""

View File

@@ -1,5 +1,8 @@
from typing import Dict, List, Union
import numpy as np
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
HardwareState,
LinkStatus,
@@ -10,15 +13,17 @@ from primaite.common.enums import (
)
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
def transform_action_node_readable(action: List[int]) -> List[Union[int, str]]:
"""Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_node_property = NodePOLType(action[1]).name
@@ -33,15 +38,18 @@ def transform_action_node_readable(action):
return new_action
def transform_action_acl_readable(action):
def transform_action_acl_readable(action: List[str]) -> List[Union[str, int]]:
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: The same action list, but with the encodings translated back into meaningful labels
:rtype: List[Union[int,str]]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
@@ -58,7 +66,7 @@ def transform_action_acl_readable(action):
return new_action
def is_valid_node_action(action):
def is_valid_node_action(action: List[int]) -> bool:
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
@@ -67,8 +75,11 @@ def is_valid_node_action(action):
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_node_readable(action)
@@ -93,7 +104,7 @@ def is_valid_node_action(action):
return True
def is_valid_acl_action(action):
def is_valid_acl_action(action: List[int]) -> bool:
"""
Is the ACL action an actual valid action.
@@ -103,8 +114,11 @@ def is_valid_acl_action(action):
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
action_r = transform_action_acl_readable(action)
@@ -126,12 +140,15 @@ def is_valid_acl_action(action):
return True
def is_valid_acl_action_extra(action):
def is_valid_acl_action_extra(action: List[int]) -> bool:
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Agent action, formatted as a list of ints, for more information check out
`primaite.environment.primaite_env.Primaite`
:type action: List[int]
:return: Whether the action is valid
:rtype: bool
"""
if is_valid_acl_action(action) is False:
return False
@@ -150,15 +167,16 @@ def is_valid_acl_action_extra(action):
return True
def transform_change_obs_readable(obs):
"""
Transform list of transactions to readable list of each observation property.
def transform_change_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform list of transactions to readable list of each observation property.
example:
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
ids = [i for i in obs[:, 0]]
operating_states = [HardwareState(i).name for i in obs[:, 1]]
@@ -173,14 +191,16 @@ def transform_change_obs_readable(obs):
return new_obs
def transform_obs_readable(obs):
"""
Transform observation to readable format.
def transform_obs_readable(obs: np.ndarray) -> List[List[Union[str, int]]]:
"""Transform observation to readable format.
example
np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: Raw observation from the environment.
:type obs: np.ndarray
:return: The same observation, but the encoded integer values are replaced with readable names.
:rtype: List[List[Union[str, int]]]
"""
changed_obs = transform_change_obs_readable(obs)
new_obs = list(zip(*changed_obs))
@@ -190,21 +210,23 @@ def transform_obs_readable(obs):
return new_obs
def convert_to_new_obs(obs, num_nodes=10):
"""
Convert original gym Box observation space to new multiDiscrete observation space.
def convert_to_new_obs(obs: np.ndarray, num_nodes: int = 10) -> np.ndarray:
"""Convert original gym Box observation space to new multiDiscrete observation space.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: observation in the 'old' (NodeLinkTable) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:return: reformatted observation
:rtype: np.ndarray
"""
# Remove ID columns, remove links and flatten to MultiDiscrete observation space
new_obs = obs[:num_nodes, 1:].flatten()
return new_obs
def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
"""
Convert to old observation.
def convert_to_old_obs(obs: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1) -> np.ndarray:
"""Convert to old observation.
Links filled with 0's as no information is included in new observation space.
@@ -216,8 +238,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
[ 3, 1, 1, 1],
...
[20, 0, 0, 0]])
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs: observation in the 'new' (MultiDiscrete) format
:type obs: np.ndarray
:param num_nodes: number of nodes in the network, defaults to 10
:type num_nodes: int, optional
:param num_links: number of links in the network, defaults to 10
:type num_links: int, optional
:param num_services: number of services on the network, defaults to 1
:type num_services: int, optional
:return: 2-d BOX observation space, in the same format as NodeLinkTable
:rtype: np.ndarray
"""
# Convert back to more readable, original format
reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2)
@@ -239,17 +270,28 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
return new_obs
def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
"""
Return string describing change between two observations.
def describe_obs_change(
obs1: np.ndarray, obs2: np.ndarray, num_nodes: int = 10, num_links: int = 10, num_services: int = 1
) -> str:
"""Build a string describing the difference between two observations.
example:
obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
output = 'ID 1: SERVICE 2 set to GOOD'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs1: First observation
:type obs1: np.ndarray
:param obs2: Second observation
:type obs2: np.ndarray
:param num_nodes: How many nodes are in the network laydown, defaults to 10
:type num_nodes: int, optional
:param num_links: How many links are in the network laydown, defaults to 10
:type num_links: int, optional
:param num_services: How many services are configured for this scenario, defaults to 1
:type num_services: int, optional
:return: A multi-line string with a human-readable description of the difference.
:rtype: str
"""
obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
@@ -268,7 +310,7 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
return change_string
def _describe_obs_change_helper(obs_change, is_link):
def _describe_obs_change_helper(obs_change: List[int], is_link: bool) -> str:
"""
Helper funcion to describe what has changed.
@@ -277,8 +319,14 @@ def _describe_obs_change_helper(obs_change, is_link):
Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param obs_change: List of integers generated within the `describe_obs_change` function. It should correspond to one
row of the observation table, and have `-1` at locations where the observation hasn't changed, and the new
status where it has changed.
:type obs_change: List[int]
:param is_link: Whether the row of the observation space corresponds to a link. False means it represents a node.
:type is_link: bool
:return: A human-readable description of the difference between the two observation rows.
:rtype: str
"""
# Indexes where a change has occured, not including 0th index
index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
@@ -304,15 +352,15 @@ def _describe_obs_change_helper(obs_change, is_link):
return desc
def transform_action_node_enum(action):
"""
Convert a node action from readable string format, to enumerated format.
def transform_action_node_enum(action: List[Union[str, int]]) -> List[int]:
"""Convert a node action from readable string format, to enumerated format.
example:
[1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: Action in 'readable' format
:type action: List[Union[str,int]]
:return: Action with verbs encoded as ints
:rtype: List[int]
"""
action_node_id = action[0]
action_node_property = NodePOLType[action[1]].value
@@ -336,63 +384,14 @@ def transform_action_node_enum(action):
return new_action
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
return new_action
def node_action_description(action):
"""
Generate string describing a node-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[1], (int, np.int64)):
# transform action to readable format
action = transform_action_node_readable(action)
node_id = action[0]
node_property = action[1]
property_action = action[2]
service_id = action[3]
if property_action == "NONE":
return ""
if node_property == "OPERATING" or node_property == "OS":
description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
elif node_property == "SERVICE":
description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
else:
return ""
return description
def transform_action_acl_enum(action):
def transform_action_acl_enum(action: List[Union[int, str]]) -> np.ndarray:
"""
Convert acl action from readable str format, to enumerated format.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param action: ACL-based action expressed as a list of human-readable ints and strings
:type action: List[Union[int,str]]
:return: The same action but encoded to contain only integers.
:rtype: np.ndarray
"""
action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
action_permissions = {"DENY": 0, "ALLOW": 1}
@@ -410,35 +409,17 @@ def transform_action_acl_enum(action):
return new_action
def acl_action_description(action):
"""
Generate string describing an acl-based action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if isinstance(action[0], (int, np.int64)):
# transform action to readable format
action = transform_action_acl_readable(action)
if action[0] == "NONE":
description = "NO ACL RULE APPLIED"
else:
description = (
f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
f" for protocol/service index {action[4]} on port index {action[5]}"
)
return description
def get_node_of_ip(ip, node_dict):
"""
Get the node ID of an IP address.
def get_node_of_ip(ip: str, node_dict: Dict[str, NodeUnion]) -> str:
"""Get the node ID of an IP address.
node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param ip: The IP address of the node whose ID is required
:type ip: str
:param node_dict: The environment's node registry dictionary
:type node_dict: Dict[str,NodeUnion]
:return: The key from the registry dict that corresponds to the node with the IP adress provided by `ip`
:rtype: str
"""
for node_key, node_value in node_dict.items():
node_ip = node_value.ip_address
@@ -446,104 +427,18 @@ def get_node_of_ip(ip, node_dict):
return node_key
def is_valid_node_action(action):
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in [
"NONE",
"PATCHING",
]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY":
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action):
"""
Harsher version of valid acl actions, does not allow action.
TODO: Add params and return in docstring.
TODO: Typehint params and return.
"""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True
def get_new_action(old_action, action_dict):
def get_new_action(old_action: np.ndarray, action_dict: Dict[int, List]) -> int:
"""
Get new action (e.g. 32) from old action e.g. [1,1,1,0].
Old_action can be either node or acl action type
TODO: Add params and return in docstring.
TODO: Typehint params and return.
:param old_action: Action expressed as a list of choices, eg. [1,1,1,0]
:type old_action: np.ndarray
:param action_dict: Dictionary for translating the multidiscrete actions into the list-based actions.
:type action_dict: Dict[int,List]
:return: Action key correspoinding to the input `old_action`
:rtype: int
"""
for key, val in action_dict.items():
if list(val) == list(old_action):

View File

@@ -31,6 +31,16 @@ agent_identifier: PPO
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: null
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
@@ -83,58 +93,58 @@ sb3_output_verbose_level: NONE
# Generic
all_ok: 0
# Node Hardware State
off_should_be_on: -10
off_should_be_resetting: -5
on_should_be_off: -2
on_should_be_resetting: -5
resetting_should_be_on: -5
resetting_should_be_off: -2
resetting: -3
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 2
good_should_be_compromised: 5
good_should_be_overwhelmed: 5
patching_should_be_good: -5
patching_should_be_compromised: 2
patching_should_be_overwhelmed: 2
patching: -3
compromised_should_be_good: -20
compromised_should_be_patching: -20
compromised_should_be_overwhelmed: -20
compromised: -20
overwhelmed_should_be_good: -20
overwhelmed_should_be_patching: -20
overwhelmed_should_be_compromised: -20
overwhelmed: -20
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 2
good_should_be_restoring: 2
good_should_be_corrupt: 5
good_should_be_destroyed: 10
repairing_should_be_good: -5
repairing_should_be_restoring: 2
repairing_should_be_corrupt: 2
repairing_should_be_destroyed: 0
repairing: -3
restoring_should_be_good: -10
restoring_should_be_repairing: -2
restoring_should_be_corrupt: 1
restoring_should_be_destroyed: 2
restoring: -6
corrupt_should_be_good: -10
corrupt_should_be_repairing: -10
corrupt_should_be_restoring: -10
corrupt_should_be_destroyed: 2
corrupt: -10
destroyed_should_be_good: -20
destroyed_should_be_repairing: -20
destroyed_should_be_restoring: -20
destroyed_should_be_corrupt: -20
destroyed: -20
scanning: -2
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -5
green_ier_blocked: -10
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS

View File

@@ -94,64 +94,64 @@ class TrainingConfig:
# Reward values
# Generic
all_ok: int = 0
all_ok: float = 0
# Node Hardware State
off_should_be_on: int = -10
off_should_be_resetting: int = -5
on_should_be_off: int = -2
on_should_be_resetting: int = -5
resetting_should_be_on: int = -5
resetting_should_be_off: int = -2
resetting: int = -3
off_should_be_on: float = -0.001
off_should_be_resetting: float = -0.0005
on_should_be_off: float = -0.0002
on_should_be_resetting: float = -0.0005
resetting_should_be_on: float = -0.0005
resetting_should_be_off: float = -0.0002
resetting: float = -0.0003
# Node Software or Service State
good_should_be_patching: int = 2
good_should_be_compromised: int = 5
good_should_be_overwhelmed: int = 5
patching_should_be_good: int = -5
patching_should_be_compromised: int = 2
patching_should_be_overwhelmed: int = 2
patching: int = -3
compromised_should_be_good: int = -20
compromised_should_be_patching: int = -20
compromised_should_be_overwhelmed: int = -20
compromised: int = -20
overwhelmed_should_be_good: int = -20
overwhelmed_should_be_patching: int = -20
overwhelmed_should_be_compromised: int = -20
overwhelmed: int = -20
good_should_be_patching: float = 0.0002
good_should_be_compromised: float = 0.0005
good_should_be_overwhelmed: float = 0.0005
patching_should_be_good: float = -0.0005
patching_should_be_compromised: float = 0.0002
patching_should_be_overwhelmed: float = 0.0002
patching: float = -0.0003
compromised_should_be_good: float = -0.002
compromised_should_be_patching: float = -0.002
compromised_should_be_overwhelmed: float = -0.002
compromised: float = -0.002
overwhelmed_should_be_good: float = -0.002
overwhelmed_should_be_patching: float = -0.002
overwhelmed_should_be_compromised: float = -0.002
overwhelmed: float = -0.002
# Node File System State
good_should_be_repairing: int = 2
good_should_be_restoring: int = 2
good_should_be_corrupt: int = 5
good_should_be_destroyed: int = 10
repairing_should_be_good: int = -5
repairing_should_be_restoring: int = 2
repairing_should_be_corrupt: int = 2
repairing_should_be_destroyed: int = 0
repairing: int = -3
restoring_should_be_good: int = -10
restoring_should_be_repairing: int = -2
restoring_should_be_corrupt: int = 1
restoring_should_be_destroyed: int = 2
restoring: int = -6
corrupt_should_be_good: int = -10
corrupt_should_be_repairing: int = -10
corrupt_should_be_restoring: int = -10
corrupt_should_be_destroyed: int = 2
corrupt: int = -10
destroyed_should_be_good: int = -20
destroyed_should_be_repairing: int = -20
destroyed_should_be_restoring: int = -20
destroyed_should_be_corrupt: int = -20
destroyed: int = -20
scanning: int = -2
good_should_be_repairing: float = 0.0002
good_should_be_restoring: float = 0.0002
good_should_be_corrupt: float = 0.0005
good_should_be_destroyed: float = 0.001
repairing_should_be_good: float = -0.0005
repairing_should_be_restoring: float = 0.0002
repairing_should_be_corrupt: float = 0.0002
repairing_should_be_destroyed: float = 0.0000
repairing: float = -0.0003
restoring_should_be_good: float = -0.001
restoring_should_be_repairing: float = -0.0002
restoring_should_be_corrupt: float = 0.0001
restoring_should_be_destroyed: float = 0.0002
restoring: float = -0.0006
corrupt_should_be_good: float = -0.001
corrupt_should_be_repairing: float = -0.001
corrupt_should_be_restoring: float = -0.001
corrupt_should_be_destroyed: float = 0.0002
corrupt: float = -0.001
destroyed_should_be_good: float = -0.002
destroyed_should_be_repairing: float = -0.002
destroyed_should_be_restoring: float = -0.002
destroyed_should_be_corrupt: float = -0.002
destroyed: float = -0.002
scanning: float = -0.0002
# IER status
red_ier_running: int = -5
green_ier_blocked: int = -10
red_ier_running: float = -0.0005
green_ier_blocked: float = -0.001
# Patching / Reset durations
os_patching_duration: int = 5
@@ -178,6 +178,12 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system"
deterministic: bool = False
"If true, the training will be deterministic"
seed: Optional[int] = None
"The random number generator seed to be used while training the agent"
@classmethod
def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig:
"""

View File

@@ -142,10 +142,10 @@ class Primaite(Env):
self.step_info = {}
# Total reward
self.total_reward = 0
self.total_reward: float = 0
# Average reward
self.average_reward = 0
self.average_reward: float = 0
# Episode count
self.episode_count = 0
@@ -283,9 +283,9 @@ class Primaite(Env):
self._create_random_red_agent()
# Reset counters and totals
self.total_reward = 0
self.total_reward = 0.0
self.step_count = 0
self.average_reward = 0
self.average_reward = 0.0
# Update observations space and return
self.update_environent_obs()

View File

@@ -20,7 +20,7 @@ def calculate_reward_function(
red_iers,
step_count,
config_values,
):
) -> float:
"""
Compares the states of the initial and final nodes/links to get a reward.
@@ -33,7 +33,7 @@ def calculate_reward_function(
step_count: current step
config_values: Config values
"""
reward_value = 0
reward_value: float = 0.0
# For each node, compare hardware state, SoftwareState, service states
for node_key, final_node in final_nodes.items():
@@ -94,7 +94,7 @@ def calculate_reward_function(
return reward_value
def score_node_operating_state(final_node, initial_node, reference_node, config_values):
def score_node_operating_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the hardware state of a node.
@@ -104,7 +104,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_operating_state = final_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
@@ -143,7 +143,7 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
return score
def score_node_os_state(final_node, initial_node, reference_node, config_values):
def score_node_os_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the Software State of a node.
@@ -153,7 +153,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_os_state = final_node.software_state
reference_node_os_state = reference_node.software_state
@@ -194,7 +194,7 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
return score
def score_node_service_state(final_node, initial_node, reference_node, config_values):
def score_node_service_state(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the service state(s) of a node.
@@ -204,7 +204,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
reference_node: The node if there had been no red or blue effect
config_values: Config values
"""
score = 0
score: float = 0.0
final_node_services: Dict[str, Service] = final_node.services
reference_node_services: Dict[str, Service] = reference_node.services
@@ -266,7 +266,7 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
return score
def score_node_file_system(final_node, initial_node, reference_node, config_values):
def score_node_file_system(final_node, initial_node, reference_node, config_values) -> float:
"""
Calculates score relating to the file system state of a node.
@@ -275,7 +275,7 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
initial_node: The node before red and blue agents take effect
reference_node: The node if there had been no red or blue effect
"""
score = 0
score: float = 0.0
final_node_file_system_state = final_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual

View File

@@ -296,11 +296,17 @@ def apply_red_agent_node_pol(
pass
def is_red_ier_incoming(node, iers, node_pol_type):
"""
Checks if the RED IER is incoming.
def is_red_ier_incoming(node: NodeUnion, iers: Dict[str, IER], node_pol_type: NodePOLType) -> bool:
"""Checks if the RED IER is incoming.
TODO: Write more descriptive docstring with params and returns.
:param node: Destination node of the IER
:type node: NodeUnion
:param iers: Directory of IERs
:type iers: Dict[str,IER]
:param node_pol_type: Type of Pattern-Of-Life
:type node_pol_type: NodePOLType
:return: Whether the RED IER is incoming.
:rtype: bool
"""
node_id = node.node_id

View File

@@ -31,7 +31,7 @@ class Transaction(object):
"The observation space before any actions are taken"
self.obs_space_post = None
"The observation space after any actions are taken"
self.reward = None
self.reward: float = None
"The reward value"
self.action_space = None
"The action space invoked by the agent"

View File

@@ -0,0 +1,155 @@
# Training Config File
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: None
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: False
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0.0000
# Node Hardware State
off_should_be_on: -0.001
off_should_be_resetting: -0.0005
on_should_be_off: -0.0002
on_should_be_resetting: -0.0005
resetting_should_be_on: -0.0005
resetting_should_be_off: -0.0002
resetting: -0.0003
# Node Software or Service State
good_should_be_patching: 0.0002
good_should_be_compromised: 0.0005
good_should_be_overwhelmed: 0.0005
patching_should_be_good: -0.0005
patching_should_be_compromised: 0.0002
patching_should_be_overwhelmed: 0.0002
patching: -0.0003
compromised_should_be_good: -0.002
compromised_should_be_patching: -0.002
compromised_should_be_overwhelmed: -0.002
compromised: -0.002
overwhelmed_should_be_good: -0.002
overwhelmed_should_be_patching: -0.002
overwhelmed_should_be_compromised: -0.002
overwhelmed: -0.002
# Node File System State
good_should_be_repairing: 0.0002
good_should_be_restoring: 0.0002
good_should_be_corrupt: 0.0005
good_should_be_destroyed: 0.001
repairing_should_be_good: -0.0005
repairing_should_be_restoring: 0.0002
repairing_should_be_corrupt: 0.0002
repairing_should_be_destroyed: 0.0000
repairing: -0.0003
restoring_should_be_good: -0.001
restoring_should_be_repairing: -0.0002
restoring_should_be_corrupt: 0.0001
restoring_should_be_destroyed: 0.0002
restoring: -0.0006
corrupt_should_be_good: -0.001
corrupt_should_be_repairing: -0.001
corrupt_should_be_restoring: -0.001
corrupt_should_be_destroyed: 0.0002
corrupt: -0.001
destroyed_should_be_good: -0.002
destroyed_should_be_repairing: -0.002
destroyed_should_be_restoring: -0.002
destroyed_should_be_corrupt: -0.002
destroyed: -0.002
scanning: -0.0002
# IER status
red_ier_running: -0.0005
green_ier_blocked: -0.001
# Patching / Reset durations
os_patching_duration: 5 # The time taken to patch the OS
node_reset_duration: 5 # The time taken to reset a node (hardware)
service_patching_duration: 5 # The time taken to patch a service
file_system_repairing_limit: 5 # The time take to repair the file system
file_system_restoring_limit: 5 # The time take to restore the file system
file_system_scanning_limit: 5 # The time taken to scan the file system

View File

@@ -1,40 +1,94 @@
# Main Config File
# Training Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used.
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "CUSTOM" (Custom Agent)
agent_framework: SB3
# Sets which deep learning framework will be used (by RLlib ONLY).
# Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# TF2 (Tensorflow 2.X)
# TORCH (PyTorch)
deep_learning_framework: TF2
# Sets which Agent class will be used.
# Options are:
# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework)
# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework)
# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type)
# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type)
# "RANDOM" (primaite.agents.simple.RandomAgent)
# "DUMMY" (primaite.agents.simple.DummyAgent)
agent_identifier: PPO
# Sets whether Red Agent POL and IER is randomised.
# Options are:
# True
# False
random_red_agent: True
random_red_agent: False
# The (integer) seed to be used in random number generation
# Default is None (null)
seed: 67890
# Set whether the agent will be deterministic instead of stochastic
# Options are:
# True
# False
deterministic: True
# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC.
# Options are:
# "BASIC" (The current observation space only)
# "FULL" (Full environment view with actions taken and reward feedback)
hard_coded_agent_view: FULL
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# observation space
observation_space:
# flatten: true
components:
- name: NODE_LINK_TABLE
# - name: NODE_STATUSES
# - name: LINK_TRAFFIC_LEVELS
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 0
# Time delay (milliseconds) between steps for CUSTOM agents.
time_delay: 5
# Type of session to be run. Options are:
# "TRAIN" (Trains an agent)
# "EVAL" (Evaluates an agent)
# "TRAIN_EVAL" (Trains then evaluates an agent)
session_type: TRAIN_EVAL
# Environment config values
# The high value for the observation space
observation_space_high_value: 1000000000
# The Stable Baselines3 learn/eval output verbosity level:
# Options are:
# "NONE" (No Output)
# "INFO" (Info Messages (such as devices and wrappers used))
# "DEBUG" (All Messages)
sb3_output_verbose_level: NONE
# Reward values
# Generic
all_ok: 0

View File

@@ -58,7 +58,6 @@ class TempPrimaiteSession(PrimaiteSession):
def __exit__(self, type, value, tb):
shutil.rmtree(self.session_path)
shutil.rmtree(self.session_path.parent)
_LOGGER.debug(f"Deleted temp session directory: {self.session_path}")

View File

@@ -1,6 +1,7 @@
import tempfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4
from primaite import getLogger
@@ -14,9 +15,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path:
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path
session_path = Path(tempfile.gettempdir()) / "primaite" / str(uuid4())
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created temp session directory: {session_path}")
return session_path

View File

@@ -33,6 +33,9 @@ def test_primaite_session(temp_primaite_session):
# Check that the network png file exists
assert (session_path / f"network_{session.timestamp_str}.png").exists()
# Check that the saved agent exists
assert session._agent_session._saved_agent_path.exists()
# Check that both the transactions and av reward csv files exist
for file in session.learning_path.iterdir():
if file.suffix == ".csv":

View File

@@ -0,0 +1,49 @@
import pytest as pytest
from primaite.config.lay_down_config import dos_very_basic_config_path
from tests import TEST_CONFIG_ROOT
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_seeded_learning(temp_primaite_session):
"""Test running seeded learning produces the same output when ran twice."""
expected_mean_reward_per_episode = {
1: -90.703125,
2: -91.15234375,
3: -87.5,
4: -92.2265625,
5: -94.6875,
6: -91.19140625,
7: -88.984375,
8: -88.3203125,
9: -112.79296875,
10: -100.01953125,
}
with temp_primaite_session as session:
assert session._training_config.seed == 67890, (
"Expected output is based upon a agent that was trained with " "seed 67890"
)
session.learn()
actual_mean_reward_per_episode = session.learn_av_reward_per_episode()
assert actual_mean_reward_per_episode == expected_mean_reward_per_episode
@pytest.mark.skip(reason="Inconsistent results. Needs someone with RL " "knowledge to investigate further.")
@pytest.mark.parametrize(
"temp_primaite_session",
[[TEST_CONFIG_ROOT / "ppo_seeded_training_config.yaml", dos_very_basic_config_path()]],
indirect=True,
)
def test_deterministic_evaluation(temp_primaite_session):
"""Test running deterministic evaluation gives same av eward per episode."""
with temp_primaite_session as session:
# do stuff
session.learn()
session.evaluate()
eval_mean_reward = session.eval_av_reward_per_episode_csv()
assert len(set(eval_mean_reward.values())) == 1