Merge branch 'dev' into feature/1522-Random-Red-Agent-Behaviour

This commit is contained in:
Czar Echavez
2023-06-29 14:17:41 +01:00
12 changed files with 121 additions and 79 deletions

View File

@@ -31,7 +31,7 @@ def _get_primaite_config():
"INFO": logging.INFO,
"WARN": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
return primaite_config

View File

@@ -3,8 +3,8 @@
import logging
import os
import shutil
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import Optional
import pkg_resources
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
:param last_n: The number of lines to print. Default value is 10.
"""
import re
from primaite import LOG_PATH
if os.path.isfile(LOG_PATH):
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
print(re.sub(r"\n*", "", line))
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
@app.command()

View File

@@ -107,6 +107,7 @@ class Primaite(Env):
# Create a dictionary to hold all the green IERs (this will come from an external source)
self.green_iers: Dict[str, IER] = {}
self.green_iers_reference: Dict[str, IER] = {}
# Create a dictionary to hold all the node PoLs (this will come from an external source)
self.node_pol = {}
@@ -196,7 +197,6 @@ class Primaite(Env):
try:
plt.tight_layout()
nx.draw_networkx(self.network, with_labels=True)
now = datetime.now() # current date and time
file_path = session_path / f"network_{timestamp_str}.png"
plt.savefig(file_path, format="PNG")
@@ -314,6 +314,9 @@ class Primaite(Env):
for link_key, link_value in self.links.items():
link_value.clear_traffic()
for link in self.links_reference.values():
link.clear_traffic()
# Create a Transaction (metric) object for this step
transaction = Transaction(
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
@@ -351,7 +354,7 @@ class Primaite(Env):
self.network_reference,
self.nodes_reference,
self.links_reference,
self.green_iers,
self.green_iers_reference,
self.acl,
self.step_count,
) # Network PoL
@@ -377,12 +380,12 @@ class Primaite(Env):
self.nodes_post_pol,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.green_iers_reference,
self.red_iers,
self.step_count,
self.training_config,
)
print(f" Step {self.step_count} Reward: {str(reward)}")
# print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -869,6 +872,17 @@ class Primaite(Env):
ier_destination,
ier_mission_criticality,
)
self.green_iers_reference[ier_id] = IER(
ier_id,
ier_start_step,
ier_end_step,
ier_load,
ier_protocol,
ier_port,
ier_source,
ier_destination,
ier_mission_criticality,
)
def create_red_ier(self, item):
"""

View File

@@ -2,17 +2,21 @@
"""Implements reward function."""
from typing import Dict
from primaite import getLogger
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.common.service import Service
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
_LOGGER = getLogger(__name__)
def calculate_reward_function(
initial_nodes,
final_nodes,
reference_nodes,
green_iers,
green_iers_reference,
red_iers,
step_count,
config_values,
@@ -68,14 +72,36 @@ def calculate_reward_function(
reward_value += config_values.red_ier_running
# Go through each green IER - penalise if it's not running (weighted)
# but only if it's supposed to be running (it's running in reference)
for ier_key, ier_value in green_iers.items():
reference_ier = green_iers_reference[ier_key]
start_step = ier_value.get_start_step()
stop_step = ier_value.get_end_step()
if step_count >= start_step and step_count <= stop_step:
if not ier_value.get_is_running():
reward_value += (
config_values.green_ier_blocked
* ier_value.get_mission_criticality()
reference_blocked = not reference_ier.get_is_running()
live_blocked = not ier_value.get_is_running()
ier_reward = (
config_values.green_ier_blocked * ier_value.get_mission_criticality()
)
if live_blocked and not reference_blocked:
_LOGGER.debug(
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
)
reward_value += ier_reward
elif live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference and live environments. "
f"Penalty of {ier_reward} was NOT applied."
)
)
elif not live_blocked and reference_blocked:
_LOGGER.debug(
(
f"IER {ier_key} is blocked in the reference env but not in the live one. "
f"Penalty of {ier_reward} was NOT applied."
)
)
return reward_value

View File

@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
from primaite import SESSIONS_DIR, getLogger
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
from primaite.transactions.transactions_to_file import \
write_transaction_to_file
from primaite.transactions.transactions_to_file import write_transaction_to_file
_LOGGER = getLogger(__name__)
@@ -349,5 +348,3 @@ if __name__ == "__main__":
"Please provide a lay down config file using the --ldc " "argument"
)
run(training_config_path=args.tc, lay_down_config_path=args.ldc)

View File

@@ -46,6 +46,7 @@ class Node:
"""Sets the node state to ON."""
self.hardware_state = HardwareState.BOOTING
self.booting_count = self.config_values.node_booting_duration
def turn_off(self):
"""Sets the node state to OFF."""
self.hardware_state = HardwareState.OFF
@@ -64,14 +65,14 @@ class Node:
self.hardware_state = HardwareState.ON
def update_booting_status(self):
"""Updates the booting count"""
"""Updates the booting count."""
self.booting_count -= 1
if self.booting_count <= 0:
self.booting_count = 0
self.hardware_state = HardwareState.ON
def update_shutdown_status(self):
"""Updates the shutdown count"""
"""Updates the shutdown count."""
self.shutting_down_count -= 1
if self.shutting_down_count <= 0:
self.shutting_down_count = 0

View File

@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
service_value.reduce_patching_count()
def update_resetting_status(self):
"""Update resetting counter and set software state if it reached 0."""
super().update_resetting_status()
if self.resetting_count <= 0:
for service in self.services.values():
service.software_state = SoftwareState.GOOD
def update_booting_status(self):
"""Update booting counter and set software to good if it reached 0."""
super().update_booting_status()
if self.booting_count <= 0:
for service in self.services.values():
service.software_state =SoftwareState.GOOD
service.software_state = SoftwareState.GOOD

View File

@@ -17,7 +17,6 @@ def start_jupyter_session():
.. todo:: Figure out how to get this working for Linux and MacOS too.
"""
if importlib.util.find_spec("jupyter") is not None:
jupyter_cmd = "python3 -m jupyter lab"
if sys.platform == "win32":