901 - merged with changes made to dev
This commit is contained in:
@@ -25,6 +25,12 @@ steps:
|
||||
versionSpec: '$(python.version)'
|
||||
displayName: 'Use Python $(python.version)'
|
||||
|
||||
- script: |
|
||||
python -m pip install pre-commit
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
displayName: 'Run pre-commits'
|
||||
|
||||
- script: |
|
||||
python -m pip install --upgrade pip==23.0.1
|
||||
pip install wheel==0.38.4 --upgrade
|
||||
|
||||
@@ -31,7 +31,7 @@ def _get_primaite_config():
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources
|
||||
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
import re
|
||||
|
||||
from primaite import LOG_PATH
|
||||
|
||||
if os.path.isfile(LOG_PATH):
|
||||
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -45,7 +45,7 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
|
||||
from primaite.transactions.transaction import Transaction
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
# _LOGGER.setLevel(logging.INFO)
|
||||
_LOGGER.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Primaite(Env):
|
||||
@@ -107,6 +107,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||
self.green_iers: Dict[str, IER] = {}
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
@@ -202,7 +203,6 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
datetime.now() # current date and time
|
||||
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
@@ -218,10 +218,22 @@ class Primaite(Env):
|
||||
# Define Action Space - depends on action space type (Node or ACL)
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
_LOGGER.info("Action space type NODE selected")
|
||||
# Terms (for node action space):
|
||||
# [0, num nodes] - node ID (0 = nothing, node ID)
|
||||
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
|
||||
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
|
||||
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
|
||||
self.action_dict = self.create_node_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
_LOGGER.info("Action space type ACL selected")
|
||||
# Terms (for ACL action space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
|
||||
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
|
||||
self.action_dict = self.create_acl_action_dict()
|
||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||
elif self.training_config.action_type == ActionType.ANY:
|
||||
@@ -304,6 +316,9 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
@@ -341,7 +356,7 @@ class Primaite(Env):
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
@@ -368,6 +383,7 @@ class Primaite(Env):
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
@@ -426,6 +442,7 @@ class Primaite(Env):
|
||||
_action: The action space from the agent
|
||||
"""
|
||||
# At the moment, actions are only affecting nodes
|
||||
|
||||
if self.training_config.action_type == ActionType.NODE:
|
||||
self.apply_actions_to_nodes(_action)
|
||||
elif self.training_config.action_type == ActionType.ACL:
|
||||
@@ -860,6 +877,17 @@ class Primaite(Env):
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
self.green_iers_reference[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
|
||||
@@ -2,17 +2,21 @@
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
@@ -68,14 +72,36 @@ def calculate_reward_function(
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
_LOGGER.debug(
|
||||
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||
)
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -349,5 +348,3 @@ if __name__ == "__main__":
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class Node:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self):
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
@@ -64,14 +65,14 @@ class Node:
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Updates the booting count"""
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self):
|
||||
"""Updates the shutdown count"""
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
self.shutting_down_count = 0
|
||||
|
||||
@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self):
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state =SoftwareState.GOOD
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
@@ -17,7 +17,6 @@ def start_jupyter_session():
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
jupyter_cmd = "python3 -m jupyter lab"
|
||||
if sys.platform == "win32":
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Used to test Active Node functions."""
|
||||
import pytest
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
|
||||
from primaite.common.enums import (
|
||||
FileSystemState,
|
||||
HardwareState,
|
||||
NodeType,
|
||||
Priority,
|
||||
SoftwareState,
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"starting_operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.RESETTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.RESETTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node resets correctly.
|
||||
"""
|
||||
"""Tests that a node resets correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = "0",
|
||||
name = "node",
|
||||
node_type = NodeType.COMPUTER,
|
||||
priority = Priority.P1,
|
||||
hardware_state = starting_operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.COMPROMISED,
|
||||
file_system_state = FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig()
|
||||
node_id="0",
|
||||
name="node",
|
||||
node_type=NodeType.COMPUTER,
|
||||
priority=Priority.P1,
|
||||
hardware_state=starting_operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.COMPROMISED,
|
||||
file_system_state=FileSystemState.CORRUPT,
|
||||
config_values=TrainingConfig(),
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
|
||||
assert active_node.file_system_state_actual == FileSystemState.GOOD
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.BOOTING, HardwareState.ON)
|
||||
],
|
||||
[(HardwareState.BOOTING, HardwareState.ON)],
|
||||
)
|
||||
def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node boots correctly.
|
||||
"""
|
||||
"""Tests that a node boots correctly."""
|
||||
service_node = ServiceNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
service_attributes = Service(
|
||||
name = "node",
|
||||
port = "80",
|
||||
software_state = SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(
|
||||
service_attributes
|
||||
name="node", port="80", software_state=SoftwareState.COMPROMISED
|
||||
)
|
||||
service_node.add_service(service_attributes)
|
||||
|
||||
for x in range(5):
|
||||
service_node.update_booting_status()
|
||||
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
|
||||
assert service_attributes.software_state == SoftwareState.GOOD
|
||||
assert service_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"operating_state, expected_operating_state",
|
||||
[
|
||||
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
|
||||
],
|
||||
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
|
||||
)
|
||||
def test_node_shutdown_correctly(operating_state, expected_operating_state):
|
||||
"""
|
||||
Tests that a node shutdown correctly.
|
||||
"""
|
||||
"""Tests that a node shutdown correctly."""
|
||||
active_node = ActiveNode(
|
||||
node_id = 0,
|
||||
name = "node",
|
||||
node_type = "COMPUTER",
|
||||
priority = "1",
|
||||
hardware_state = operating_state,
|
||||
ip_address = "192.168.0.1",
|
||||
software_state = SoftwareState.GOOD,
|
||||
file_system_state = "GOOD",
|
||||
config_values = 1,
|
||||
node_id=0,
|
||||
name="node",
|
||||
node_type="COMPUTER",
|
||||
priority="1",
|
||||
hardware_state=operating_state,
|
||||
ip_address="192.168.0.1",
|
||||
software_state=SoftwareState.GOOD,
|
||||
file_system_state="GOOD",
|
||||
config_values=1,
|
||||
)
|
||||
|
||||
for x in range(5):
|
||||
active_node.update_shutdown_status()
|
||||
|
||||
assert active_node.hardware_state == expected_operating_state
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user