diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 8a79d068..cc559b4d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -23,7 +23,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter -from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.network.nmne import store_nmne_config, NmneData from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -113,6 +113,9 @@ class PrimaiteGame: self._reward_calculation_order: List[str] = [name for name in self.agents] """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards.""" + self.nmne_config: NmneData = None + """ Config data from Number of Malicious Network Events.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -496,10 +499,11 @@ class PrimaiteGame: # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. game.setup_reward_sharing() - # Set the NMNE capture config - set_nmne_config(network_config.get("nmne_config", {})) game.update_agents(game.get_sim_state()) + # Set the NMNE capture config + game.nmne_config = store_nmne_config(network_config.get("nmne_config", {})) + return game def setup_reward_sharing(self): @@ -539,3 +543,6 @@ class PrimaiteGame: # sort the agents so the rewards that depend on other rewards are always evaluated later self._reward_calculation_order = topological_sort(graph) + + def get_nmne_config(self) -> NmneData: + return self.nmne_config diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index d6f1763f..947f27ac 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -4,7 +4,7 @@ from typing import Dict, List @dataclass -class nmne_data: +class NmneData: """Store all the information to perform NMNE operations.""" capture_nmne: bool = True @@ -23,10 +23,9 @@ class nmne_data: """Captures should be filtered and categorised based on specific keywords.""" -def set_nmne_config(nmne_config: Dict) -> nmne_data: +def store_nmne_config(nmne_config: Dict) -> NmneData: """ - Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided - dictionary. + Store configuration for capturing Malicious Network Events (MNEs). This function updates global settings related to NMNE capture, including whether to capture NMNEs and what keywords to use for identifying NMNEs. @@ -41,7 +40,7 @@ def set_nmne_config(nmne_config: Dict) -> nmne_data: "nmne_capture_keywords" (list of strings) to specify keywords for NMNE identification. :rvar dataclass with data read from config file. """ - nmne_capture_keywords = [] + nmne_capture_keywords: List[str] = [] # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect capture_nmne = nmne_config.get("capture_nmne", False) if not isinstance(capture_nmne, bool): @@ -52,4 +51,4 @@ def set_nmne_config(nmne_config: Dict) -> nmne_data: if not isinstance(nmne_capture_keywords, list): nmne_capture_keywords = [] # Reset to empty list if the provided value is not a list - return nmne_data(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords) + return NmneData(capture_nmne=capture_nmne, nmne_capture_keywords=nmne_capture_keywords)