diff --git a/docs/source/how_to_guides/extensible_rewards.rst b/docs/source/how_to_guides/extensible_rewards.rst new file mode 100644 index 00000000..a01b9d8f --- /dev/null +++ b/docs/source/how_to_guides/extensible_rewards.rst @@ -0,0 +1,57 @@ +.. only:: comment + + © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK + +.. _about: + +Extensible Rewards +****************** +Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x as new reward +types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE +core repository). + +Changes to reward class structure. +================================== + +Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel). +Within the reward class there is a ConfigSchema class responsible for ensuring the config file data +is in the correct format. This also means there is little (if no) requirement for and `__init__` +method. The `.from_config` method is no longer required as it's inherited from `AbstractReward`. +Each class requires an identifier string which is used by the ConfigSchema class to verify that it +hasn't previously been added to the registry. + +Inheriting from `BaseModel` removes the need for an `__init__` method but means that object +attributes need to be passed by keyword. + +To add a new reward class follow the example below. Note that the type attribute in the +`ConfigSchema` class should match the type used in the config file to define the reward. + +.. code-block:: Python + +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): + """Reward function component which rewards the agent for maintaining the integrity of a database file.""" + + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 + + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "DATABASE_FILE_INTEGRITY" + node_hostname: str + folder_name: str + file_name: str + + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + pass + + + +Changes to YAML file. +===================== +.. code:: YAML + + There's no longer a need to provide a `dns_server` as an option in the simulation section + of the config file. diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 7c184770..a4c7c546 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -27,9 +27,10 @@ the structure: service_ref: web_server_database_client ``` """ -from abc import abstractmethod -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from abc import ABC, abstractmethod +from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union +from pydantic import BaseModel from typing_extensions import Never from primaite import getLogger @@ -42,25 +43,32 @@ _LOGGER = getLogger(__name__) WhereType = Optional[Iterable[Union[str, int]]] -class AbstractReward: +class AbstractReward(BaseModel): """Base class for reward function components.""" - @abstractmethod - def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state. + config: "AbstractReward.ConfigSchema" - :param state: Current simulation state - :type state: Dict - :param last_action_response: Current agent history state - :type last_action_response: AgentHistoryItem state - :return: Reward value - :rtype: float - """ - return 0.0 + # def __init__(self, schema_name, **kwargs): + # super.__init__(self, **kwargs) + # # Create ConfigSchema class + # self.config_class = type(schema_name, (BaseModel, ABC), **kwargs) + # self.config = self.config_class() + + class ConfigSchema(BaseModel, ABC): + """Config schema for AbstractReward.""" + + type: str + + _registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {} + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate reward {identifier}") + cls._registry[identifier] = cls @classmethod - @abstractmethod - def from_config(cls, config: dict) -> "AbstractReward": + def from_config(cls, config: Dict) -> "AbstractReward": """Create a reward function component from a config dictionary. :param config: dict of options for the reward component's constructor @@ -68,11 +76,28 @@ class AbstractReward: :return: The reward component. :rtype: AbstractReward """ - return cls() + if config["type"] not in cls._registry: + raise ValueError(f"Invalid reward type {config['type']}") + reward_class = cls._registry[config["type"]] + reward_obj = reward_class(config=reward_class.ConfigSchema(**config)) + return reward_obj + + @abstractmethod + def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ + return 0.0 -class DummyReward(AbstractReward): - """Dummy reward function component which always returns 0.""" +class DummyReward(AbstractReward, identifier="DUMMY"): + """Dummy reward function component which always returns 0.0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -86,41 +111,21 @@ class DummyReward(AbstractReward): """ return 0.0 - @classmethod - def from_config(cls, config: dict) -> "DummyReward": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor. Should be empty. - :type config: dict - :return: The reward component. - :rtype: DummyReward - """ - return cls() - - -class DatabaseFileIntegrity(AbstractReward): +class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"): """Reward function component which rewards the agent for maintaining the integrity of a database file.""" - def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None: - """Initialise the reward component. + config: "DatabaseFileIntegrity.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the database file. - :type node_hostname: str - :param folder_name: folder which contains the database file. - :type folder_name: str - :param file_name: name of the database file. - :type file_name: str - """ - self.location_in_state = [ - "network", - "nodes", - node_hostname, - "file_system", - "folders", - folder_name, - "files", - file_name, - ] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for DatabaseFileIntegrity.""" + + type: str = "DATABASE_FILE_INTEGRITY" + node_hostname: str + folder_name: str + file_name: str def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -132,6 +137,17 @@ class DatabaseFileIntegrity(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "file_system", + "folders", + self.config.folder_name, + "files", + self.config.file_name, + ] + database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: _LOGGER.debug( @@ -148,44 +164,21 @@ class DatabaseFileIntegrity(AbstractReward): else: return 0 - @classmethod - def from_config(cls, config: Dict) -> "DatabaseFileIntegrity": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: DatabaseFileIntegrity - """ - node_hostname = config.get("node_hostname") - folder_name = config.get("folder_name") - file_name = config.get("file_name") - if not (node_hostname and folder_name and file_name): - msg = f"{cls.__name__} could not be initialised with parameters {config}" - _LOGGER.error(msg) - raise ValueError(msg) - - return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name) - - -class WebServer404Penalty(AbstractReward): +class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"): """Reward function component which penalises the agent when the web server returns a 404 error.""" - def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None: - """Initialise the reward component. + config: "WebServer404Penalty.ConfigSchema" + location_in_state: List[str] = [""] + reward: float = 0.0 - :param node_hostname: Hostname of the node which contains the web server service. - :type node_hostname: str - :param service_name: Name of the web server service. - :type service_name: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" - self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebServer404Penalty.""" + + type: str = "WEB_SERVER_404_PENALTY" + node_hostname: str + service_name: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. @@ -197,6 +190,13 @@ class WebServer404Penalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "services", + self.config.service_name, + ] web_service_state = access_from_nested_dict(state, self.location_in_state) # if webserver is no longer installed on the node, return 0 @@ -211,54 +211,27 @@ class WebServer404Penalty(AbstractReward): return 1.0 if status == 200 else -1.0 if status == 404 else 0.0 self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average - elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0 + elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0 self.reward = 0.0 else: # skip calculating if sticky and no new codes. instead, reuse last step's value pass return self.reward - @classmethod - def from_config(cls, config: Dict) -> "WebServer404Penalty": - """Create a reward function component from a config dictionary. - :param config: dict of options for the reward component's constructor - :type config: Dict - :return: The reward component. - :rtype: WebServer404Penalty - """ - node_hostname = config.get("node_hostname") - service_name = config.get("service_name") - if not (node_hostname and service_name): - msg = ( - f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " - "found in reward config." - ) - _LOGGER.warning(msg) - raise ValueError(msg) - sticky = config.get("sticky", True) - - return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky) - - -class WebpageUnavailablePenalty(AbstractReward): +class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"): """Penalises the agent when the web browser fails to fetch a webpage.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "WebpageUnavailablePenalty.ConfigSchema" + reward: float = 0.0 + location_in_state: List[str] = [""] # Calculate in __init__()? - :param node_hostname: Hostname of the node which has the web browser. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for WebpageUnavailablePenalty.""" + + type: str = "WEBPAGE_UNAVAILABLE_PENALTY" + node_hostname: str = "" + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -274,6 +247,13 @@ class WebpageUnavailablePenalty(AbstractReward): :return: Reward value :rtype: float """ + self.location_in_state = [ + "network", + "nodes", + self.config.node_hostname, + "applications", + "WebBrowser", + ] web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE: @@ -283,14 +263,14 @@ class WebpageUnavailablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", "WebBrowser", "execute", ] # skip calculating if sticky and no new codes, reusing last step value - if not request_attempted and self.sticky: + if not request_attempted and self.config.sticky: return self.reward if last_action_response.response.status != "success": @@ -298,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward): elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]: _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", - f"{self._node} was not reported in the simulation state. Returning 0.0", + f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0", ) self.reward = 0.0 else: @@ -312,37 +292,19 @@ class WebpageUnavailablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class GreenAdminDatabaseUnreachablePenalty(AbstractReward): +class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"): """Penalises the agent when the green db clients fail to connect to the database.""" - def __init__(self, node_hostname: str, sticky: bool = True) -> None: - """ - Initialise the reward component. + config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema" + reward: float = 0.0 - :param node_hostname: Hostname of the node where the database client sits. - :type node_hostname: str - :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate - the reward if there were any responses this timestep. - :type sticky: bool - """ - self._node: str = node_hostname - self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - self.sticky: bool = sticky - self.reward: float = 0.0 - """Reward value calculated last time any responses were seen. Used for persisting sticky rewards.""" + class ConfigSchema(AbstractReward.ConfigSchema): + """ConfigSchema for GreenAdminDatabaseUnreachablePenalty.""" + + type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY" + node_hostname: str + sticky: bool = True def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """ @@ -362,7 +324,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): request_attempted = last_action_response.request == [ "network", "node", - self._node, + self.config.node_hostname, "application", "DatabaseClient", "execute", @@ -371,7 +333,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): if request_attempted: # if agent makes request, always recalculate fresh value last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status} self.reward = 1.0 if last_action_response.response.status == "success" else -1.0 - elif not self.sticky: # if no new request and not sticky, set reward to 0 + elif not self.config.sticky: # if no new request and not sticky, set reward to 0 last_action_response.reward_info = {"connection_attempt_status": "n/a"} self.reward = 0.0 else: # if no new request and sticky, reuse reward value from last step @@ -380,47 +342,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): return self.reward - @classmethod - def from_config(cls, config: Dict) -> AbstractReward: - """ - Build the reward component object from config. - :param config: Configuration dictionary. - :type config: Dict - """ - node_hostname = config.get("node_hostname") - sticky = config.get("sticky", True) - return cls(node_hostname=node_hostname, sticky=sticky) - - -class SharedReward(AbstractReward): +class SharedReward(AbstractReward, identifier="SHARED_REWARD"): """Adds another agent's reward to the overall reward.""" - def __init__(self, agent_name: Optional[str] = None) -> None: + config: "SharedReward.ConfigSchema" + + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for SharedReward.""" + + type: str = "SHARED_REWARD" + agent_name: str + + def default_callback(agent_name: str) -> Never: """ - Initialise the shared reward. + Default callback to prevent calling this reward until it's properly initialised. - The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work - correctly. - - :param agent_name: The name whose reward is an input - :type agent_name: Optional[str] + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. """ - self.agent_name = agent_name - """Agent whose reward to track.""" + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - def default_callback(agent_name: str) -> Never: - """ - Default callback to prevent calling this reward until it's properly initialised. - - SharedReward should not be used until the game layer replaces self.callback with a reference to the - function that retrieves the desired agent's reward. Therefore, we define this default callback that raises - an error. - """ - raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") - - self.callback: Callable[[str], float] = default_callback - """Method that retrieves an agent's current reward given the agent's name.""" + callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Simply access the other agent's reward and return it. @@ -432,36 +377,20 @@ class SharedReward(AbstractReward): :return: Reward value :rtype: float """ - return self.callback(self.agent_name) - - @classmethod - def from_config(cls, config: Dict) -> "SharedReward": - """ - Build the SharedReward object from config. - - :param config: Configuration dictionary - :type config: Dict - """ - agent_name = config.get("agent_name") - return cls(agent_name=agent_name) + return self.callback(self.config.agent_name) -class ActionPenalty(AbstractReward): +class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"): """Apply a negative reward when taking any action except DONOTHING.""" - def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None: - """ - Initialise the reward. + config: "ActionPenalty.ConfigSchema" - Reward or penalise agents for doing nothing or taking actions. + class ConfigSchema(AbstractReward.ConfigSchema): + """Config schema for ActionPenalty.""" - :param action_penalty: Reward to give agents for taking any action except DONOTHING - :type action_penalty: float - :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action - :type do_nothing_penalty: float - """ - self.action_penalty = action_penalty - self.do_nothing_penalty = do_nothing_penalty + type: str = "ACTION_PENALTY" + action_penalty: float = -1.0 + do_nothing_penalty: float = 0.0 def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the penalty to be applied. @@ -474,32 +403,14 @@ class ActionPenalty(AbstractReward): :rtype: float """ if last_action_response.action == "DONOTHING": - return self.do_nothing_penalty + return self.config.do_nothing_penalty else: - return self.action_penalty - - @classmethod - def from_config(cls, config: Dict) -> "ActionPenalty": - """Build the ActionPenalty object from config.""" - action_penalty = config.get("action_penalty", -1.0) - do_nothing_penalty = config.get("do_nothing_penalty", 0.0) - return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty) + return self.config.action_penalty class RewardFunction: """Manages the reward function for the agent.""" - rew_class_identifiers: Dict[str, Type[AbstractReward]] = { - "DUMMY": DummyReward, - "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, - "WEB_SERVER_404_PENALTY": WebServer404Penalty, - "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, - "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, - "SHARED_REWARD": SharedReward, - "ACTION_PENALTY": ActionPenalty, - } - """List of reward class identifiers.""" - def __init__(self): """Initialise the reward function object.""" self.reward_components: List[Tuple[AbstractReward, float]] = [] @@ -534,7 +445,7 @@ class RewardFunction: @classmethod def from_config(cls, config: Dict) -> "RewardFunction": - """Create a reward function from a config dictionary. + """Create a reward function from a config dictionary and its related reward class. :param config: dict of options for the reward manager's constructor :type config: Dict @@ -545,8 +456,11 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] + # XXX: If options key is missing add key then add type key. + if "options" not in rew_component_cfg: + rew_component_cfg["options"] = {} + rew_component_cfg["options"]["type"] = rew_type weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.rew_class_identifiers[rew_type] - rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) + rew_instance = AbstractReward.from_config(rew_component_cfg["options"]) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index d8b28e94..5523c33c 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -376,7 +376,7 @@ class PrimaiteGame: if service_class is not None: _LOGGER.debug(f"installing {service_type} on node {new_node.hostname}") - new_node.software_manager.install(service_class) + new_node.software_manager.install(service_class, **service_cfg.get("options", {})) new_service = new_node.software_manager.software[service_class.__name__] # fixing duration for the service @@ -575,7 +575,7 @@ class PrimaiteGame: for comp, weight in agent.reward_function.reward_components: if isinstance(comp, SharedReward): comp: SharedReward - graph[name].add(comp.agent_name) + graph[name].add(comp.config.agent_name) # while constructing the graph, we might as well set up the reward sharing itself. comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb similarity index 99% rename from src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb rename to src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index 6e6819fa..d2972fa9 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -1780,10 +1780,11 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] @@ -1804,7 +1805,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -1818,7 +1819,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index f573f251..117ea019 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -166,9 +166,10 @@ "from pathlib import Path\n", "from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n", "from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.file_system.file_system import FileSystem\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "from primaite.utils.validation.port import PORT_LOOKUP\n", + "\n", "\n", "# no applications exist yet so we will create our own.\n", "class MSPaint(Application, identifier=\"MSPaint\"):\n", @@ -182,7 +183,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { @@ -249,7 +250,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 2d5b4772..4b620eee 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -532,12 +532,12 @@ }, "outputs": [], "source": [ - "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", - "from primaite.simulator.network.transmission.transport_layer import Port\n", "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", + "from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n", + "\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol["ICMP"],\n", + " protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" @@ -650,7 +650,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -664,7 +664,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 4b7286de..96130e16 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -318,6 +318,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"): """ if not self._can_perform_action(): return None + if self.server_ip_address is None: + self.sys_log.warning(f"{self.name}: Database server IP address not provided.") + return None connection_request_id = str(uuid4()) self._client_connection_requests[connection_request_id] = None diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index 48f1f383..d6617efa 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP from primaite.utils.validation.port import Port, PORT_LOOKUP if TYPE_CHECKING: - from primaite.simulator.network.hardware.base import NetworkInterface + from primaite.simulator.network.hardware.base import NetworkInterface, Node from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.core.sys_log import SysLog diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index fed0f52d..799bb571 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -202,8 +202,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: diff --git a/tests/assets/configs/fixing_duration_one_item.yaml b/tests/assets/configs/fixing_duration_one_item.yaml index 57c1c4ce..025c0a6f 100644 --- a/tests/assets/configs/fixing_duration_one_item.yaml +++ b/tests/assets/configs/fixing_duration_one_item.yaml @@ -200,8 +200,6 @@ simulation: port_scan_p_of_success: 0.8 services: - type: DNSClient - options: - dns_server: 192.168.1.10 - type: DNSServer options: domain_mapping: @@ -232,8 +230,6 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/assets/configs/software_fixing_duration.yaml b/tests/assets/configs/software_fixing_duration.yaml index bb1254ed..3fd4e7a1 100644 --- a/tests/assets/configs/software_fixing_duration.yaml +++ b/tests/assets/configs/software_fixing_duration.yaml @@ -250,8 +250,6 @@ simulation: server_password: arcd services: - type: DNSClient - options: - dns_server: 192.168.1.10 links: - endpoint_a_hostname: switch_1 diff --git a/tests/integration_tests/extensions/test_extendable_config.py b/tests/integration_tests/extensions/test_extendable_config.py index 5515d900..8e73f929 100644 --- a/tests/integration_tests/extensions/test_extendable_config.py +++ b/tests/integration_tests/extensions/test_extendable_config.py @@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests import TEST_ASSETS_ROOT from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch @@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch from tests.integration_tests.extensions.nodes.super_computer import SuperComputer from tests.integration_tests.extensions.services.extended_service import ExtendedService +CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml" + def test_extended_example_config(): """Test that the example config can be parsed properly.""" - config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml") - game = load_config(config_path) + game = load_config(CONFIG_PATH) network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index a674d864..045f510f 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -18,12 +18,14 @@ from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent -def test_WebpageUnavailablePenalty(game_and_agent): +def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for failing to fetch a website.""" # set up the scenario, configure the web browser to the correct url game, agent = game_and_agent agent: ControlledAgent - comp = WebpageUnavailablePenalty(node_hostname="client_1") + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = WebpageUnavailablePenalty(config=schema) + client_1 = game.simulation.network.get_node_by_hostname("client_1") browser: WebBrowser = client_1.software_manager.software.get("WebBrowser") browser.run() @@ -53,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent): assert agent.reward_function.current_reward == -0.7 -def test_uc2_rewards(game_and_agent): +def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" game, agent = game_and_agent agent: ControlledAgent @@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent): ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2 ) - comp = GreenAdminDatabaseUnreachablePenalty("client_1") + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True) + comp = GreenAdminDatabaseUnreachablePenalty(config=schema) request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"] response = game.simulation.apply_request(request) @@ -139,15 +142,17 @@ def test_action_penalty_loads_from_config(): act_penalty_obj = comp[0] if act_penalty_obj is None: pytest.fail("Action penalty reward component was not added to the agent from config.") - assert act_penalty_obj.action_penalty == -0.75 - assert act_penalty_obj.do_nothing_penalty == 0.125 + assert act_penalty_obj.config.action_penalty == -0.75 + assert act_penalty_obj.config.do_nothing_penalty == 0.125 def test_action_penalty(): """Test that the action penalty is correctly applied when agent performs any action""" # Create an ActionPenalty Reward - Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + # Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + Penalty = ActionPenalty(config=schema) # Assert that penalty is applied if action isn't DONOTHING reward_value = Penalty.calculate( @@ -178,11 +183,12 @@ def test_action_penalty(): assert reward_value == 0.125 -def test_action_penalty_e2e(game_and_agent): +def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]): """Test that we get the right reward for doing actions to fetch a website.""" game, agent = game_and_agent agent: ControlledAgent - comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125) + schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125) + comp = ActionPenalty(config=schema) agent.reward_function.register_component(comp, 1.0) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py index 0e4bf1bb..346eb512 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_sticky_rewards.py @@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse class TestWebServer404PenaltySticky: def test_non_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=False) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=False, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebServer404Penalty("computer", "WebService", sticky=True) + schema = WebServer404Penalty.ConfigSchema( + node_hostname="computer", + service_name="WebService", + sticky=True, + ) + reward = WebServer404Penalty(config=schema) # no response codes yet, reward is 0 codes = [] @@ -67,7 +77,8 @@ class TestWebServer404PenaltySticky: class TestWebpageUnavailabilitySticky: def test_non_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=False) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -127,7 +138,8 @@ class TestWebpageUnavailabilitySticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = WebpageUnavailablePenalty("computer", sticky=True) + schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True) + reward = WebpageUnavailablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -188,7 +200,11 @@ class TestWebpageUnavailabilitySticky: class TestGreenAdminDatabaseUnreachableSticky: def test_non_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=False, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] @@ -214,7 +230,6 @@ class TestGreenAdminDatabaseUnreachableSticky: # agent did nothing, because reward is not sticky, it goes back to 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"] response = RequestResponse(status="success", data={}) - browser_history = [] state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}} last_action_response = AgentHistoryItem( timestep=0, action=action, parameters=params, request=request, response=response @@ -244,7 +259,11 @@ class TestGreenAdminDatabaseUnreachableSticky: assert reward.calculate(state, last_action_response) == -1.0 def test_sticky(self): - reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True) + schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema( + node_hostname="computer", + sticky=True, + ) + reward = GreenAdminDatabaseUnreachablePenalty(config=schema) # no response codes yet, reward is 0 action, params, request = "DO_NOTHING", {}, ["DONOTHING"]