Merge branch 'dev' into feature/1522-Random-Red-Agent-Behaviour

This commit is contained in:
Czar Echavez
2023-06-29 14:17:41 +01:00
12 changed files with 121 additions and 79 deletions

View File

@@ -25,6 +25,12 @@ steps:
versionSpec: '$(python.version)'
displayName: 'Use Python $(python.version)'
- script: |
python -m pip install pre-commit
pre-commit install
pre-commit run --all-files
displayName: 'Run pre-commits'
- script: |
python -m pip install --upgrade pip==23.0.1
pip install wheel==0.38.4 --upgrade

View File

@@ -31,7 +31,7 @@ def _get_primaite_config():
"INFO": logging.INFO,
"WARN": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
return primaite_config

View File

@@ -3,8 +3,8 @@
import logging
import os
import shutil
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from primaite import LOG_PATH
if os.path.isfile(LOG_PATH):
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()

View File

@@ -107,6 +107,7 @@ class Primaite(Env):
# Create a dictionary to hold all the green IERs (this will come from an external source)
self.green_iers: Dict[str, IER] = {}
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {}
@@ -196,7 +197,6 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -314,6 +314,9 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
link_value.clear_traffic()
for link in self.links_reference.values():
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
@@ -351,7 +354,7 @@ class Primaite(Env):
self.network_reference,
self.nodes_reference,
self.links_reference,
self.green_iers,
self.green_iers_reference,
self.acl,
self.step_count,
) # Network PoL
@@ -377,12 +380,12 @@ class Primaite(Env):
self.nodes_post_pol,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.green_iers_reference,
self.red_iers,
self.step_count,
self.training_config,
)
print(f" Step {self.step_count} Reward: {str(reward)}")
# print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -869,6 +872,17 @@ class Primaite(Env):
ier_destination,
ier_mission_criticality,
)
self.green_iers_reference[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_red_ier(self, item):
"""

View File

@@ -2,17 +2,21 @@
"""Implements reward function."""
from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
@@ -68,14 +72,36 @@ def calculate_reward_function(
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
# but only if it's supposed to be running (it's running in reference)
for ier_key, ier_value in green_iers.items():
reference_ier = green_iers_reference[ier_key]
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running():
reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
)
if live_blocked and not reference_blocked:
_LOGGER.debug(
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
)
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference and live environments. "
f"Penalty of {ier_reward} was NOT applied."
)
)
elif not live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference env but not in the live one. "
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value

View File

@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = getLogger(__name__)
@@ -349,5 +348,3 @@ if __name__ == "__main__":
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -46,6 +46,7 @@ class Node:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
@@ -64,14 +65,14 @@ class Node:
self.hardware_state = HardwareState.ON
def update_booting_status(self):
"""Updates the booting count"""
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
"""Updates the shutdown count"""
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0

View File

@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
service_value.reduce_patching_count()
def update_resetting_status(self):
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state =SoftwareState.GOOD
service.software_state = SoftwareState.GOOD

View File

@@ -17,7 +17,6 @@ def start_jupyter_session():
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if importlib.util.find_spec("jupyter") is not None:
jupyter_cmd = "python3 -m jupyter lab"
if sys.platform == "win32":

View File

@@ -27,7 +27,8 @@ def env(request):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -43,7 +44,8 @@ def test_default_obs_space(env: Primaite):
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
@@ -140,7 +142,8 @@ class TestNodeLinkTable:
@pytest.mark.env_config_paths(
dict(
training_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)

View File

@@ -1,7 +1,13 @@
"""Used to test Active Node functions."""
import pytest
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState, NodeType, Priority
from primaite.common.enums import (
FileSystemState,
HardwareState,
NodeType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.config.training_config import TrainingConfig
from primaite.nodes.active_node import ActiveNode
@@ -10,24 +16,20 @@ from primaite.nodes.service_node import ServiceNode
@pytest.mark.parametrize(
"starting_operating_state, expected_operating_state",
[
(HardwareState.RESETTING, HardwareState.ON)
],
[(HardwareState.RESETTING, HardwareState.ON)],
)
def test_node_resets_correctly(starting_operating_state, expected_operating_state):
"""
Tests that a node resets correctly.
"""
"""Tests that a node resets correctly."""
active_node = ActiveNode(
node_id = "0",
name = "node",
node_type = NodeType.COMPUTER,
priority = Priority.P1,
hardware_state = starting_operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.COMPROMISED,
file_system_state = FileSystemState.CORRUPT,
config_values=TrainingConfig()
node_id="0",
name="node",
node_type=NodeType.COMPUTER,
priority=Priority.P1,
hardware_state=starting_operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.COMPROMISED,
file_system_state=FileSystemState.CORRUPT,
config_values=TrainingConfig(),
)
for x in range(5):
@@ -37,35 +39,28 @@ def test_node_resets_correctly(starting_operating_state, expected_operating_stat
assert active_node.file_system_state_actual == FileSystemState.GOOD
assert active_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.BOOTING, HardwareState.ON)
],
[(HardwareState.BOOTING, HardwareState.ON)],
)
def test_node_boots_correctly(operating_state, expected_operating_state):
"""
Tests that a node boots correctly.
"""
"""Tests that a node boots correctly."""
service_node = ServiceNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
service_attributes = Service(
name = "node",
port = "80",
software_state = SoftwareState.COMPROMISED
)
service_node.add_service(
service_attributes
name="node", port="80", software_state=SoftwareState.COMPROMISED
)
service_node.add_service(service_attributes)
for x in range(5):
service_node.update_booting_status()
@@ -73,31 +68,26 @@ def test_node_boots_correctly(operating_state, expected_operating_state):
assert service_attributes.software_state == SoftwareState.GOOD
assert service_node.hardware_state == expected_operating_state
@pytest.mark.parametrize(
"operating_state, expected_operating_state",
[
(HardwareState.SHUTTING_DOWN, HardwareState.OFF)
],
[(HardwareState.SHUTTING_DOWN, HardwareState.OFF)],
)
def test_node_shutdown_correctly(operating_state, expected_operating_state):
"""
Tests that a node shutdown correctly.
"""
"""Tests that a node shutdown correctly."""
active_node = ActiveNode(
node_id = 0,
name = "node",
node_type = "COMPUTER",
priority = "1",
hardware_state = operating_state,
ip_address = "192.168.0.1",
software_state = SoftwareState.GOOD,
file_system_state = "GOOD",
config_values = 1,
node_id=0,
name="node",
node_type="COMPUTER",
priority="1",
hardware_state=operating_state,
ip_address="192.168.0.1",
software_state=SoftwareState.GOOD,
file_system_state="GOOD",
config_values=1,
)
for x in range(5):
active_node.update_shutdown_status()
assert active_node.hardware_state == expected_operating_state

View File

@@ -48,7 +48,8 @@ def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
run_generic_set_actions(env)
@@ -77,8 +78,10 @@ def test_single_action_space_is_valid():
def test_agent_is_executing_actions_from_both_spaces():
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env = _get_primaite_env_from_config(
training_config_path=TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml",
training_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env)