Merged PR 567: Extensible Rewards

## Summary
*Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?*

## Test process
*How have you tested this (if applicable)?*

## Checklist
- [X] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2913
This commit is contained in:
Nick Todd
2025-01-03 11:54:20 +00:00
committed by Marek Wolan
14 changed files with 273 additions and 279 deletions

View File

@@ -0,0 +1,57 @@
.. only:: comment
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
.. _about:
Extensible Rewards
******************
Extensible Rewards differ from the previous reward mechanism used in PrimAITE v3.x in that new reward
types can be added without requiring a change to the RewardFunction class in rewards.py (PrimAITE
core repository).
Changes to reward class structure.
==================================
Reward classes are inherited from AbstractReward (a sub-class of Pydantic's BaseModel).
Within the reward class there is a ConfigSchema class responsible for ensuring the config file data
is in the correct format. This also means there is little (if any) requirement for an `__init__`
method. Reward classes no longer need to define a `.from_config` method as it's inherited from `AbstractReward`.
Each class requires an identifier string, which `AbstractReward.__init_subclass__` uses to add the class to the
reward registry after verifying that the identifier hasn't previously been registered.
Inheriting from `BaseModel` removes the need for an `__init__` method but means that object
attributes need to be passed by keyword.
To add a new reward class, follow the example below. Note that the type attribute in the
`ConfigSchema` class should match the type used in the config file to define the reward.
.. code-block:: Python
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
config: "DatabaseFileIntegrity.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
type: str = "DATABASE_FILE_INTEGRITY"
node_hostname: str
folder_name: str
file_name: str
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
pass
Changes to YAML file.
=====================
.. code:: YAML
There's no longer a need to provide a `dns_server` as an option in the simulation section
of the config file.

View File

@@ -27,9 +27,10 @@ the structure:
service_ref: web_server_database_client
```
"""
from abc import abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
from pydantic import BaseModel
from typing_extensions import Never
from primaite import getLogger
@@ -42,25 +43,32 @@ _LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward:
class AbstractReward(BaseModel):
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
config: "AbstractReward.ConfigSchema"
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
# def __init__(self, schema_name, **kwargs):
# super.__init__(self, **kwargs)
# # Create ConfigSchema class
# self.config_class = type(schema_name, (BaseModel, ABC), **kwargs)
# self.config = self.config_class()
class ConfigSchema(BaseModel, ABC):
"""Config schema for AbstractReward."""
type: str
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate reward {identifier}")
cls._registry[identifier] = cls
@classmethod
@abstractmethod
def from_config(cls, config: dict) -> "AbstractReward":
def from_config(cls, config: Dict) -> "AbstractReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
@@ -68,11 +76,28 @@ class AbstractReward:
:return: The reward component.
:rtype: AbstractReward
"""
return cls()
if config["type"] not in cls._registry:
raise ValueError(f"Invalid reward type {config['type']}")
reward_class = cls._registry[config["type"]]
reward_obj = reward_class(config=reward_class.ConfigSchema(**config))
return reward_obj
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: Current simulation state
:type state: Dict
:param last_action_response: Current agent history state
:type last_action_response: AgentHistoryItem state
:return: Reward value
:rtype: float
"""
return 0.0
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
class DummyReward(AbstractReward, identifier="DUMMY"):
"""Dummy reward function component which always returns 0.0."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -86,41 +111,21 @@ class DummyReward(AbstractReward):
"""
return 0.0
@classmethod
def from_config(cls, config: dict) -> "DummyReward":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor. Should be empty.
:type config: dict
:return: The reward component.
:rtype: DummyReward
"""
return cls()
class DatabaseFileIntegrity(AbstractReward):
class DatabaseFileIntegrity(AbstractReward, identifier="DATABASE_FILE_INTEGRITY"):
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
"""Initialise the reward component.
config: "DatabaseFileIntegrity.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
:param node_hostname: Hostname of the node which contains the database file.
:type node_hostname: str
:param folder_name: folder which contains the database file.
:type folder_name: str
:param file_name: name of the database file.
:type file_name: str
"""
self.location_in_state = [
"network",
"nodes",
node_hostname,
"file_system",
"folders",
folder_name,
"files",
file_name,
]
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for DatabaseFileIntegrity."""
type: str = "DATABASE_FILE_INTEGRITY"
node_hostname: str
folder_name: str
file_name: str
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -132,6 +137,17 @@ class DatabaseFileIntegrity(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"file_system",
"folders",
self.config.folder_name,
"files",
self.config.file_name,
]
database_file_state = access_from_nested_dict(state, self.location_in_state)
if database_file_state is NOT_PRESENT_IN_STATE:
_LOGGER.debug(
@@ -148,44 +164,21 @@ class DatabaseFileIntegrity(AbstractReward):
else:
return 0
@classmethod
def from_config(cls, config: Dict) -> "DatabaseFileIntegrity":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: DatabaseFileIntegrity
"""
node_hostname = config.get("node_hostname")
folder_name = config.get("folder_name")
file_name = config.get("file_name")
if not (node_hostname and folder_name and file_name):
msg = f"{cls.__name__} could not be initialised with parameters {config}"
_LOGGER.error(msg)
raise ValueError(msg)
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
class WebServer404Penalty(AbstractReward):
class WebServer404Penalty(AbstractReward, identifier="WEB_SERVER_404_PENALTY"):
"""Reward function component which penalises the agent when the web server returns a 404 error."""
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
"""Initialise the reward component.
config: "WebServer404Penalty.ConfigSchema"
location_in_state: List[str] = [""]
reward: float = 0.0
:param node_hostname: Hostname of the node which contains the web server service.
:type node_hostname: str
:param service_name: Name of the web server service.
:type service_name: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebServer404Penalty."""
type: str = "WEB_SERVER_404_PENALTY"
node_hostname: str
service_name: str
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
@@ -197,6 +190,13 @@ class WebServer404Penalty(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"services",
self.config.service_name,
]
web_service_state = access_from_nested_dict(state, self.location_in_state)
# if webserver is no longer installed on the node, return 0
@@ -211,54 +211,27 @@ class WebServer404Penalty(AbstractReward):
return 1.0 if status == 200 else -1.0 if status == 404 else 0.0
self.reward = sum(map(status2rew, codes)) / len(codes) # convert form HTTP codes to rewards and average
elif not self.sticky: # there are no codes, but reward is not sticky, set reward to 0
elif not self.config.sticky: # there are no codes, but reward is not sticky, set reward to 0
self.reward = 0.0
else: # skip calculating if sticky and no new codes. instead, reuse last step's value
pass
return self.reward
@classmethod
def from_config(cls, config: Dict) -> "WebServer404Penalty":
"""Create a reward function component from a config dictionary.
:param config: dict of options for the reward component's constructor
:type config: Dict
:return: The reward component.
:rtype: WebServer404Penalty
"""
node_hostname = config.get("node_hostname")
service_name = config.get("service_name")
if not (node_hostname and service_name):
msg = (
f"{cls.__name__} could not be initialised from config because node_name and service_ref were not "
"found in reward config."
)
_LOGGER.warning(msg)
raise ValueError(msg)
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
class WebpageUnavailablePenalty(AbstractReward):
class WebpageUnavailablePenalty(AbstractReward, identifier="WEBPAGE_UNAVAILABLE_PENALTY"):
"""Penalises the agent when the web browser fails to fetch a webpage."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
config: "WebpageUnavailablePenalty.ConfigSchema"
reward: float = 0.0
location_in_state: List[str] = [""] # Calculate in __init__()?
:param node_hostname: Hostname of the node which has the web browser.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for WebpageUnavailablePenalty."""
type: str = "WEBPAGE_UNAVAILABLE_PENALTY"
node_hostname: str = ""
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -274,6 +247,13 @@ class WebpageUnavailablePenalty(AbstractReward):
:return: Reward value
:rtype: float
"""
self.location_in_state = [
"network",
"nodes",
self.config.node_hostname,
"applications",
"WebBrowser",
]
web_browser_state = access_from_nested_dict(state, self.location_in_state)
if web_browser_state is NOT_PRESENT_IN_STATE:
@@ -283,14 +263,14 @@ class WebpageUnavailablePenalty(AbstractReward):
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
self.config.node_hostname,
"application",
"WebBrowser",
"execute",
]
# skip calculating if sticky and no new codes, reusing last step value
if not request_attempted and self.sticky:
if not request_attempted and self.config.sticky:
return self.reward
if last_action_response.response.status != "success":
@@ -298,7 +278,7 @@ class WebpageUnavailablePenalty(AbstractReward):
elif web_browser_state is NOT_PRESENT_IN_STATE or not web_browser_state["history"]:
_LOGGER.debug(
"Web browser reward could not be calculated because the web browser history on node",
f"{self._node} was not reported in the simulation state. Returning 0.0",
f"{self.config.node_hostname} was not reported in the simulation state. Returning 0.0",
)
self.reward = 0.0
else:
@@ -312,37 +292,19 @@ class WebpageUnavailablePenalty(AbstractReward):
return self.reward
@classmethod
def from_config(cls, config: dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"):
"""Penalises the agent when the green db clients fail to connect to the database."""
def __init__(self, node_hostname: str, sticky: bool = True) -> None:
"""
Initialise the reward component.
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
reward: float = 0.0
:param node_hostname: Hostname of the node where the database client sits.
:type node_hostname: str
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
the reward if there were any responses this timestep.
:type sticky: bool
"""
self._node: str = node_hostname
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self.sticky: bool = sticky
self.reward: float = 0.0
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
class ConfigSchema(AbstractReward.ConfigSchema):
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
type: str = "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY"
node_hostname: str
sticky: bool = True
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
@@ -362,7 +324,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
request_attempted = last_action_response.request == [
"network",
"node",
self._node,
self.config.node_hostname,
"application",
"DatabaseClient",
"execute",
@@ -371,7 +333,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
if request_attempted: # if agent makes request, always recalculate fresh value
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
elif not self.sticky: # if no new request and not sticky, set reward to 0
elif not self.config.sticky: # if no new request and not sticky, set reward to 0
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
self.reward = 0.0
else: # if no new request and sticky, reuse reward value from last step
@@ -380,47 +342,30 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
return self.reward
@classmethod
def from_config(cls, config: Dict) -> AbstractReward:
"""
Build the reward component object from config.
:param config: Configuration dictionary.
:type config: Dict
"""
node_hostname = config.get("node_hostname")
sticky = config.get("sticky", True)
return cls(node_hostname=node_hostname, sticky=sticky)
class SharedReward(AbstractReward):
class SharedReward(AbstractReward, identifier="SHARED_REWARD"):
"""Adds another agent's reward to the overall reward."""
def __init__(self, agent_name: Optional[str] = None) -> None:
config: "SharedReward.ConfigSchema"
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for SharedReward."""
type: str = "SHARED_REWARD"
agent_name: str
def default_callback(agent_name: str) -> Never:
"""
Initialise the shared reward.
Default callback to prevent calling this reward until it's properly initialised.
The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
correctly.
:param agent_name: The name whose reward is an input
:type agent_name: Optional[str]
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
self.agent_name = agent_name
"""Agent whose reward to track."""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
def default_callback(agent_name: str) -> Never:
"""
Default callback to prevent calling this reward until it's properly initialised.
SharedReward should not be used until the game layer replaces self.callback with a reference to the
function that retrieves the desired agent's reward. Therefore, we define this default callback that raises
an error.
"""
raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.")
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it.
@@ -432,36 +377,20 @@ class SharedReward(AbstractReward):
:return: Reward value
:rtype: float
"""
return self.callback(self.agent_name)
@classmethod
def from_config(cls, config: Dict) -> "SharedReward":
"""
Build the SharedReward object from config.
:param config: Configuration dictionary
:type config: Dict
"""
agent_name = config.get("agent_name")
return cls(agent_name=agent_name)
return self.callback(self.config.agent_name)
class ActionPenalty(AbstractReward):
class ActionPenalty(AbstractReward, identifier="ACTION_PENALTY"):
"""Apply a negative reward when taking any action except DONOTHING."""
def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
"""
Initialise the reward.
config: "ActionPenalty.ConfigSchema"
Reward or penalise agents for doing nothing or taking actions.
class ConfigSchema(AbstractReward.ConfigSchema):
"""Config schema for ActionPenalty."""
:param action_penalty: Reward to give agents for taking any action except DONOTHING
:type action_penalty: float
:param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
:type do_nothing_penalty: float
"""
self.action_penalty = action_penalty
self.do_nothing_penalty = do_nothing_penalty
type: str = "ACTION_PENALTY"
action_penalty: float = -1.0
do_nothing_penalty: float = 0.0
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the penalty to be applied.
@@ -474,32 +403,14 @@ class ActionPenalty(AbstractReward):
:rtype: float
"""
if last_action_response.action == "DONOTHING":
return self.do_nothing_penalty
return self.config.do_nothing_penalty
else:
return self.action_penalty
@classmethod
def from_config(cls, config: Dict) -> "ActionPenalty":
"""Build the ActionPenalty object from config."""
action_penalty = config.get("action_penalty", -1.0)
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
return self.config.action_penalty
class RewardFunction:
"""Manages the reward function for the agent."""
rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
"WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty,
"GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty,
"SHARED_REWARD": SharedReward,
"ACTION_PENALTY": ActionPenalty,
}
"""List of reward class identifiers."""
def __init__(self):
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
@@ -534,7 +445,7 @@ class RewardFunction:
@classmethod
def from_config(cls, config: Dict) -> "RewardFunction":
"""Create a reward function from a config dictionary.
"""Create a reward function from a config dictionary and its related reward class.
:param config: dict of options for the reward manager's constructor
:type config: Dict
@@ -545,8 +456,11 @@ class RewardFunction:
for rew_component_cfg in config["reward_components"]:
rew_type = rew_component_cfg["type"]
# XXX: If options key is missing add key then add type key.
if "options" not in rew_component_cfg:
rew_component_cfg["options"] = {}
rew_component_cfg["options"]["type"] = rew_type
weight = rew_component_cfg.get("weight", 1.0)
rew_class = cls.rew_class_identifiers[rew_type]
rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {}))
rew_instance = AbstractReward.from_config(rew_component_cfg["options"])
new.register_component(component=rew_instance, weight=weight)
return new

View File

@@ -376,7 +376,7 @@ class PrimaiteGame:
if service_class is not None:
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_class)
new_node.software_manager.install(service_class, **service_cfg.get("options", {}))
new_service = new_node.software_manager.software[service_class.__name__]
# fixing duration for the service
@@ -629,7 +629,7 @@ class PrimaiteGame:
for comp, weight in agent.reward_function.reward_components:
if isinstance(comp, SharedReward):
comp: SharedReward
graph[name].add(comp.agent_name)
graph[name].add(comp.config.agent_name)
# while constructing the graph, we might as well set up the reward sharing itself.
comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward

View File

@@ -1780,10 +1780,11 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"from primaite.utils.validation.port import PORT_LOOKUP\n",
"\n",
"# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n",
"c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=PROTOCOL_LOOKUP[\"UDP\"], masquerade_port=PORT_LOOKUP[\"DNS\"])\n",
"c2_beacon.establish()\n",
"c2_beacon.show()"
]
@@ -1804,7 +1805,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -1818,7 +1819,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -166,9 +166,10 @@
"from pathlib import Path\n",
"from primaite.simulator.system.applications.application import Application, ApplicationOperatingState\n",
"from primaite.simulator.system.software import SoftwareHealthState, SoftwareCriticality\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.file_system.file_system import FileSystem\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"from primaite.utils.validation.port import PORT_LOOKUP\n",
"\n",
"\n",
"# no applications exist yet so we will create our own.\n",
"class MSPaint(Application, identifier=\"MSPaint\"):\n",
@@ -182,7 +183,7 @@
"metadata": {},
"outputs": [],
"source": [
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
"mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=PORT_LOOKUP[\"HTTP\"], protocol = PROTOCOL_LOOKUP[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)"
]
},
{
@@ -249,7 +250,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": ".venv",
"language": "python",
"name": "python3"
},

View File

@@ -532,12 +532,12 @@
},
"outputs": [],
"source": [
"from primaite.simulator.network.transmission.network_layer import IPProtocol\n",
"from primaite.simulator.network.transmission.transport_layer import Port\n",
"from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n",
"from primaite.utils.validation.ip_protocol import PROTOCOL_LOOKUP\n",
"\n",
"network.get_node_by_hostname(\"router_1\").acl.add_rule(\n",
" action=ACLAction.DENY,\n",
" protocol=IPProtocol["ICMP"],\n",
" protocol=PROTOCOL_LOOKUP[\"ICMP\"],\n",
" src_ip_address=\"192.168.10.22\",\n",
" position=1\n",
")"
@@ -650,7 +650,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -664,7 +664,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -308,6 +308,9 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
if self.server_ip_address is None:
self.sys_log.warning(f"{self.name}: Database server IP address not provided.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None

View File

@@ -16,7 +16,7 @@ from primaite.utils.validation.ip_protocol import IPProtocol, PROTOCOL_LOOKUP
from primaite.utils.validation.port import Port, PORT_LOOKUP
if TYPE_CHECKING:
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog

View File

@@ -202,8 +202,6 @@ simulation:
port_scan_p_of_success: 0.8
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
- type: DNSServer
options:
domain_mapping:

View File

@@ -200,8 +200,6 @@ simulation:
port_scan_p_of_success: 0.8
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
- type: DNSServer
options:
domain_mapping:
@@ -232,8 +230,6 @@ simulation:
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1

View File

@@ -209,7 +209,6 @@ simulation:
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
fix_duration: 3
- type: DNSServer
options:
@@ -250,8 +249,6 @@ simulation:
server_password: arcd
services:
- type: DNSClient
options:
dns_server: 192.168.1.10
links:
- endpoint_a_hostname: switch_1

View File

@@ -5,6 +5,7 @@ from primaite.config.load import get_extended_config_path
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from tests import TEST_ASSETS_ROOT
from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config
from tests.integration_tests.extensions.applications.extended_application import ExtendedApplication
from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
@@ -13,11 +14,12 @@ from tests.integration_tests.extensions.nodes.giga_switch import GigaSwitch
from tests.integration_tests.extensions.nodes.super_computer import SuperComputer
from tests.integration_tests.extensions.services.extended_service import ExtendedService
CONFIG_PATH = TEST_ASSETS_ROOT / "configs/extended_config.yaml"
def test_extended_example_config():
"""Test that the example config can be parsed properly."""
config_path = os.path.join("tests", "assets", "configs", "extended_config.yaml")
game = load_config(config_path)
game = load_config(CONFIG_PATH)
network: Network = game.simulation.network
assert len(network.nodes) == 10 # 10 nodes in example network

View File

@@ -18,12 +18,14 @@ from tests import TEST_ASSETS_ROOT
from tests.conftest import ControlledAgent
def test_WebpageUnavailablePenalty(game_and_agent):
def test_WebpageUnavailablePenalty(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that we get the right reward for failing to fetch a website."""
# set up the scenario, configure the web browser to the correct url
game, agent = game_and_agent
agent: ControlledAgent
comp = WebpageUnavailablePenalty(node_hostname="client_1")
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
comp = WebpageUnavailablePenalty(config=schema)
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.run()
@@ -53,7 +55,7 @@ def test_WebpageUnavailablePenalty(game_and_agent):
assert agent.reward_function.current_reward == -0.7
def test_uc2_rewards(game_and_agent):
def test_uc2_rewards(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that the reward component correctly applies a penalty when the selected client cannot reach the database."""
game, agent = game_and_agent
agent: ControlledAgent
@@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent):
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
)
comp = GreenAdminDatabaseUnreachablePenalty("client_1")
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
comp = GreenAdminDatabaseUnreachablePenalty(config=schema)
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
response = game.simulation.apply_request(request)
@@ -139,15 +142,17 @@ def test_action_penalty_loads_from_config():
act_penalty_obj = comp[0]
if act_penalty_obj is None:
pytest.fail("Action penalty reward component was not added to the agent from config.")
assert act_penalty_obj.action_penalty == -0.75
assert act_penalty_obj.do_nothing_penalty == 0.125
assert act_penalty_obj.config.action_penalty == -0.75
assert act_penalty_obj.config.do_nothing_penalty == 0.125
def test_action_penalty():
"""Test that the action penalty is correctly applied when agent performs any action"""
# Create an ActionPenalty Reward
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
# Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
Penalty = ActionPenalty(config=schema)
# Assert that penalty is applied if action isn't DONOTHING
reward_value = Penalty.calculate(
@@ -178,11 +183,12 @@ def test_action_penalty():
assert reward_value == 0.125
def test_action_penalty_e2e(game_and_agent):
def test_action_penalty_e2e(game_and_agent: tuple[PrimaiteGame, ControlledAgent]):
"""Test that we get the right reward for doing actions to fetch a website."""
game, agent = game_and_agent
agent: ControlledAgent
comp = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
comp = ActionPenalty(config=schema)
agent.reward_function.register_component(comp, 1.0)

View File

@@ -11,7 +11,12 @@ from primaite.interface.request import RequestResponse
class TestWebServer404PenaltySticky:
def test_non_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=False)
schema = WebServer404Penalty.ConfigSchema(
node_hostname="computer",
service_name="WebService",
sticky=False,
)
reward = WebServer404Penalty(config=schema)
# no response codes yet, reward is 0
codes = []
@@ -38,7 +43,12 @@ class TestWebServer404PenaltySticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebServer404Penalty("computer", "WebService", sticky=True)
schema = WebServer404Penalty.ConfigSchema(
node_hostname="computer",
service_name="WebService",
sticky=True,
)
reward = WebServer404Penalty(config=schema)
# no response codes yet, reward is 0
codes = []
@@ -67,7 +77,8 @@ class TestWebServer404PenaltySticky:
class TestWebpageUnavailabilitySticky:
def test_non_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=False)
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=False)
reward = WebpageUnavailablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
@@ -127,7 +138,8 @@ class TestWebpageUnavailabilitySticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = WebpageUnavailablePenalty("computer", sticky=True)
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="computer", sticky=True)
reward = WebpageUnavailablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
@@ -188,7 +200,11 @@ class TestWebpageUnavailabilitySticky:
class TestGreenAdminDatabaseUnreachableSticky:
def test_non_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
node_hostname="computer",
sticky=False,
)
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
@@ -214,7 +230,6 @@ class TestGreenAdminDatabaseUnreachableSticky:
# agent did nothing, because reward is not sticky, it goes back to 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
response = RequestResponse(status="success", data={})
browser_history = []
state = {"network": {"nodes": {"computer": {"applications": {"DatabaseClient": {}}}}}}
last_action_response = AgentHistoryItem(
timestep=0, action=action, parameters=params, request=request, response=response
@@ -244,7 +259,11 @@ class TestGreenAdminDatabaseUnreachableSticky:
assert reward.calculate(state, last_action_response) == -1.0
def test_sticky(self):
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(
node_hostname="computer",
sticky=True,
)
reward = GreenAdminDatabaseUnreachablePenalty(config=schema)
# no response codes yet, reward is 0
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]