Merge branch 'dev' into feature/1522-Random-Red-Agent-Behaviour
This commit is contained in:
@@ -31,7 +31,7 @@ def _get_primaite_config():
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["log_level"]]
|
||||
return primaite_config
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources
|
||||
@@ -44,6 +44,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
:param last_n: The number of lines to print. Default value is 10.
|
||||
"""
|
||||
import re
|
||||
|
||||
from primaite import LOG_PATH
|
||||
|
||||
if os.path.isfile(LOG_PATH):
|
||||
@@ -53,7 +54,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]):
|
||||
print(re.sub(r"\n*", "", line))
|
||||
|
||||
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -107,6 +107,7 @@ class Primaite(Env):
|
||||
|
||||
# Create a dictionary to hold all the green IERs (this will come from an external source)
|
||||
self.green_iers: Dict[str, IER] = {}
|
||||
self.green_iers_reference: Dict[str, IER] = {}
|
||||
|
||||
# Create a dictionary to hold all the node PoLs (this will come from an external source)
|
||||
self.node_pol = {}
|
||||
@@ -196,7 +197,6 @@ class Primaite(Env):
|
||||
try:
|
||||
plt.tight_layout()
|
||||
nx.draw_networkx(self.network, with_labels=True)
|
||||
now = datetime.now() # current date and time
|
||||
|
||||
file_path = session_path / f"network_{timestamp_str}.png"
|
||||
plt.savefig(file_path, format="PNG")
|
||||
@@ -314,6 +314,9 @@ class Primaite(Env):
|
||||
for link_key, link_value in self.links.items():
|
||||
link_value.clear_traffic()
|
||||
|
||||
for link in self.links_reference.values():
|
||||
link.clear_traffic()
|
||||
|
||||
# Create a Transaction (metric) object for this step
|
||||
transaction = Transaction(
|
||||
datetime.now(), self.agent_identifier, self.episode_count, self.step_count
|
||||
@@ -351,7 +354,7 @@ class Primaite(Env):
|
||||
self.network_reference,
|
||||
self.nodes_reference,
|
||||
self.links_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.acl,
|
||||
self.step_count,
|
||||
) # Network PoL
|
||||
@@ -377,12 +380,12 @@ class Primaite(Env):
|
||||
self.nodes_post_pol,
|
||||
self.nodes_post_red,
|
||||
self.nodes_reference,
|
||||
self.green_iers,
|
||||
self.green_iers_reference,
|
||||
self.red_iers,
|
||||
self.step_count,
|
||||
self.training_config,
|
||||
)
|
||||
print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
# print(f" Step {self.step_count} Reward: {str(reward)}")
|
||||
self.total_reward += reward
|
||||
if self.step_count == self.episode_steps:
|
||||
self.average_reward = self.total_reward / self.step_count
|
||||
@@ -869,6 +872,17 @@ class Primaite(Env):
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
self.green_iers_reference[ier_id] = IER(
|
||||
ier_id,
|
||||
ier_start_step,
|
||||
ier_end_step,
|
||||
ier_load,
|
||||
ier_protocol,
|
||||
ier_port,
|
||||
ier_source,
|
||||
ier_destination,
|
||||
ier_mission_criticality,
|
||||
)
|
||||
|
||||
def create_red_ier(self, item):
|
||||
"""
|
||||
|
||||
@@ -2,17 +2,21 @@
|
||||
"""Implements reward function."""
|
||||
from typing import Dict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.common.service import Service
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_reward_function(
|
||||
initial_nodes,
|
||||
final_nodes,
|
||||
reference_nodes,
|
||||
green_iers,
|
||||
green_iers_reference,
|
||||
red_iers,
|
||||
step_count,
|
||||
config_values,
|
||||
@@ -68,14 +72,36 @@ def calculate_reward_function(
|
||||
reward_value += config_values.red_ier_running
|
||||
|
||||
# Go through each green IER - penalise if it's not running (weighted)
|
||||
# but only if it's supposed to be running (it's running in reference)
|
||||
for ier_key, ier_value in green_iers.items():
|
||||
reference_ier = green_iers_reference[ier_key]
|
||||
start_step = ier_value.get_start_step()
|
||||
stop_step = ier_value.get_end_step()
|
||||
if step_count >= start_step and step_count <= stop_step:
|
||||
if not ier_value.get_is_running():
|
||||
reward_value += (
|
||||
config_values.green_ier_blocked
|
||||
* ier_value.get_mission_criticality()
|
||||
reference_blocked = not reference_ier.get_is_running()
|
||||
live_blocked = not ier_value.get_is_running()
|
||||
ier_reward = (
|
||||
config_values.green_ier_blocked * ier_value.get_mission_criticality()
|
||||
)
|
||||
|
||||
if live_blocked and not reference_blocked:
|
||||
_LOGGER.debug(
|
||||
f"Applying reward of {ier_reward} because IER {ier_key} is blocked"
|
||||
)
|
||||
reward_value += ier_reward
|
||||
elif live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference and live environments. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
elif not live_blocked and reference_blocked:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
f"IER {ier_key} is blocked in the reference env but not in the live one. "
|
||||
f"Penalty of {ier_reward} was NOT applied."
|
||||
)
|
||||
)
|
||||
|
||||
return reward_value
|
||||
|
||||
@@ -22,8 +22,7 @@ from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
from primaite import SESSIONS_DIR, getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.transactions.transactions_to_file import \
|
||||
write_transaction_to_file
|
||||
from primaite.transactions.transactions_to_file import write_transaction_to_file
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -349,5 +348,3 @@ if __name__ == "__main__":
|
||||
"Please provide a lay down config file using the --ldc " "argument"
|
||||
)
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class Node:
|
||||
"""Sets the node state to ON."""
|
||||
self.hardware_state = HardwareState.BOOTING
|
||||
self.booting_count = self.config_values.node_booting_duration
|
||||
|
||||
def turn_off(self):
|
||||
"""Sets the node state to OFF."""
|
||||
self.hardware_state = HardwareState.OFF
|
||||
@@ -64,14 +65,14 @@ class Node:
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Updates the booting count"""
|
||||
"""Updates the booting count."""
|
||||
self.booting_count -= 1
|
||||
if self.booting_count <= 0:
|
||||
self.booting_count = 0
|
||||
self.hardware_state = HardwareState.ON
|
||||
|
||||
def update_shutdown_status(self):
|
||||
"""Updates the shutdown count"""
|
||||
"""Updates the shutdown count."""
|
||||
self.shutting_down_count -= 1
|
||||
if self.shutting_down_count <= 0:
|
||||
self.shutting_down_count = 0
|
||||
|
||||
@@ -190,13 +190,15 @@ class ServiceNode(ActiveNode):
|
||||
service_value.reduce_patching_count()
|
||||
|
||||
def update_resetting_status(self):
|
||||
"""Update resetting counter and set software state if it reached 0."""
|
||||
super().update_resetting_status()
|
||||
if self.resetting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
def update_booting_status(self):
|
||||
"""Update booting counter and set software to good if it reached 0."""
|
||||
super().update_booting_status()
|
||||
if self.booting_count <= 0:
|
||||
for service in self.services.values():
|
||||
service.software_state =SoftwareState.GOOD
|
||||
service.software_state = SoftwareState.GOOD
|
||||
|
||||
@@ -17,7 +17,6 @@ def start_jupyter_session():
|
||||
|
||||
.. todo:: Figure out how to get this working for Linux and MacOS too.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("jupyter") is not None:
|
||||
jupyter_cmd = "python3 -m jupyter lab"
|
||||
if sys.platform == "win32":
|
||||
|
||||
Reference in New Issue
Block a user