#2913: Remove from_config() and refactor (WIP).
This commit is contained in:
@@ -46,6 +46,14 @@ WhereType = Optional[Iterable[Union[str, int]]]
|
||||
class AbstractReward(BaseModel):
|
||||
"""Base class for reward function components."""
|
||||
|
||||
config: "AbstractReward.ConfigSchema"
|
||||
|
||||
# def __init__(self, schema_name, **kwargs):
|
||||
# super.__init__(self, **kwargs)
|
||||
# # Create ConfigSchema class
|
||||
# self.config_class = type(schema_name, (BaseModel, ABC), **kwargs)
|
||||
# self.config = self.config_class()
|
||||
|
||||
class ConfigSchema(BaseModel, ABC):
|
||||
"""Config schema for AbstractReward."""
|
||||
|
||||
@@ -56,7 +64,7 @@ class AbstractReward(BaseModel):
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate node adder {identifier}")
|
||||
raise ValueError(f"Duplicate reward {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@classmethod
|
||||
@@ -70,9 +78,10 @@ class AbstractReward(BaseModel):
|
||||
"""
|
||||
if config["type"] not in cls._registry:
|
||||
raise ValueError(f"Invalid reward type {config['type']}")
|
||||
adder_class = cls._registry[config["type"]]
|
||||
adder_class.add_nodes_to_net(config=adder_class.ConfigSchema(**config))
|
||||
return cls
|
||||
reward_class = cls._registry[config["type"]]
|
||||
reward_config = reward_class.ConfigSchema(**config)
|
||||
reward_class(config=reward_config)
|
||||
return reward_class
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
@@ -103,30 +112,18 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
|
||||
"""
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> "DummyReward":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor. Should be empty.
|
||||
:type config: dict
|
||||
:return: The reward component.
|
||||
:rtype: DummyReward
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
config: "DatabaseFileIntegrity.ConfigSchema"
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseFileIntegrity."""
|
||||
|
||||
type: str = "DatabaseFileIntegrity"
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
@@ -144,12 +141,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
self.config.node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
self.folder_name,
|
||||
self.config.folder_name,
|
||||
"files",
|
||||
self.file_name,
|
||||
self.config.file_name,
|
||||
]
|
||||
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
@@ -168,38 +165,18 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
else:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "DatabaseFileIntegrity":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:return: The reward component.
|
||||
:rtype: DatabaseFileIntegrity
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
folder_name = config.get("folder_name")
|
||||
file_name = config.get("file_name")
|
||||
if not (node_hostname and folder_name and file_name):
|
||||
msg = f"{cls.__name__} could not be initialised with parameters {config}"
|
||||
_LOGGER.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(node_hostname=node_hostname, folder_name=folder_name, file_name=file_name)
|
||||
|
||||
|
||||
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
node_hostname: str
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
config: "WebServer404Penalty.ConfigSchema"
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebServer404Penalty."""
|
||||
|
||||
type: str = "WebServer404Penalty"
|
||||
node_hostname: str
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
@@ -217,9 +194,9 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
self.config.node_hostname,
|
||||
"services",
|
||||
self.service_name,
|
||||
self.config.service_name,
|
||||
]
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
@@ -242,43 +219,20 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "WebServer404Penalty":
|
||||
"""Create a reward function component from a config dictionary.
|
||||
|
||||
:param config: dict of options for the reward component's constructor
|
||||
:type config: Dict
|
||||
:return: The reward component.
|
||||
:rtype: WebServer404Penalty
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
service_name = config.get("service_name")
|
||||
if not (node_hostname and service_name):
|
||||
msg = (
|
||||
f"{cls.__name__} could not be initialised from config because node_name and service_ref were not "
|
||||
"found in reward config."
|
||||
)
|
||||
_LOGGER.warning(msg)
|
||||
raise ValueError(msg)
|
||||
sticky = config.get("sticky", True)
|
||||
|
||||
return cls(node_hostname=node_hostname, service_name=service_name, sticky=sticky)
|
||||
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
location_in_state: List[str] = [""]
|
||||
config: "WebpageUnavailablePenalty.ConfigSchema"
|
||||
reward: float = 0.0 # XXX: Private attribute?
|
||||
location_in_state: List[str] = [""] # Calculate in __init__()?
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebpageUnavailablePenalty."""
|
||||
|
||||
type: str = "WebpageUnavailablePenalty"
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
@@ -297,7 +251,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
self.config.node_hostname,
|
||||
"applications",
|
||||
"WebBrowser",
|
||||
]
|
||||
@@ -310,14 +264,14 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self.node_hostname,
|
||||
self.config.node_hostname,
|
||||
"application",
|
||||
"WebBrowser",
|
||||
"execute",
|
||||
]
|
||||
|
||||
# skip calculating if sticky and no new codes, reusing last step value
|
||||
if not request_attempted and self.sticky:
|
||||
if not request_attempted and self.config.sticky:
|
||||
return self.reward
|
||||
|
||||
if last_action_response.response.status != "success":
|
||||
@@ -339,29 +293,17 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict) -> AbstractReward:
|
||||
"""
|
||||
Build the reward component object from config.
|
||||
|
||||
:param config: Configuration dictionary.
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
config: "GreenAdminDatabaseUnreachablePenalty.ConfigSchema"
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for GreenAdminDatabaseUnreachablePenalty."""
|
||||
|
||||
type: str = "GreenAdminDatabaseUnreachablePenalty"
|
||||
node_hostname: str
|
||||
sticky: bool = True
|
||||
|
||||
@@ -383,7 +325,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
request_attempted = last_action_response.request == [
|
||||
"network",
|
||||
"node",
|
||||
self.node_hostname,
|
||||
self.config.node_hostname,
|
||||
"application",
|
||||
"DatabaseClient",
|
||||
"execute",
|
||||
@@ -392,7 +334,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
if request_attempted: # if agent makes request, always recalculate fresh value
|
||||
last_action_response.reward_info = {"connection_attempt_status": last_action_response.response.status}
|
||||
self.reward = 1.0 if last_action_response.response.status == "success" else -1.0
|
||||
elif not self.sticky: # if no new request and not sticky, set reward to 0
|
||||
elif not self.config.sticky: # if no new request and not sticky, set reward to 0
|
||||
last_action_response.reward_info = {"connection_attempt_status": "n/a"}
|
||||
self.reward = 0.0
|
||||
else: # if no new request and sticky, reuse reward value from last step
|
||||
@@ -401,27 +343,16 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
|
||||
return self.reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> AbstractReward:
|
||||
"""
|
||||
Build the reward component object from config.
|
||||
|
||||
:param config: Configuration dictionary.
|
||||
:type config: Dict
|
||||
"""
|
||||
node_hostname = config.get("node_hostname")
|
||||
sticky = config.get("sticky", True)
|
||||
return cls(node_hostname=node_hostname, sticky=sticky)
|
||||
|
||||
|
||||
class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
agent_name: str
|
||||
config: "SharedReward.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for SharedReward."""
|
||||
|
||||
type: str = "SharedReward"
|
||||
agent_name: str
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
@@ -447,29 +378,18 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
return self.callback(self.agent_name)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "SharedReward":
|
||||
"""
|
||||
Build the SharedReward object from config.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:type config: Dict
|
||||
"""
|
||||
agent_name = config.get("agent_name")
|
||||
return cls(agent_name=agent_name)
|
||||
return self.callback(self.config.agent_name)
|
||||
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
config: "ActionPenalty.ConfigSchema"
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""Config schema for ActionPenalty."""
|
||||
|
||||
type: str = "ActionPenalty"
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
@@ -484,16 +404,9 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
:rtype: float
|
||||
"""
|
||||
if last_action_response.action == "DONOTHING":
|
||||
return self.do_nothing_penalty
|
||||
return self.config.do_nothing_penalty
|
||||
else:
|
||||
return self.action_penalty
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "ActionPenalty":
|
||||
"""Build the ActionPenalty object from config."""
|
||||
action_penalty = config.get("action_penalty", -1.0)
|
||||
do_nothing_penalty = config.get("do_nothing_penalty", 0.0)
|
||||
return cls(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
return self.config.action_penalty
|
||||
|
||||
|
||||
class RewardFunction:
|
||||
|
||||
@@ -23,7 +23,9 @@ def test_WebpageUnavailablePenalty(game_and_agent):
|
||||
# set up the scenario, configure the web browser to the correct url
|
||||
game, agent = game_and_agent
|
||||
agent: ControlledAgent
|
||||
comp = WebpageUnavailablePenalty(node_hostname="client_1")
|
||||
schema = WebpageUnavailablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
|
||||
comp = WebpageUnavailablePenalty(config=schema)
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
|
||||
browser.run()
|
||||
@@ -74,7 +76,8 @@ def test_uc2_rewards(game_and_agent):
|
||||
ACLAction.PERMIT, src_port=PORT_LOOKUP["POSTGRES_SERVER"], dst_port=PORT_LOOKUP["POSTGRES_SERVER"], position=2
|
||||
)
|
||||
|
||||
comp = GreenAdminDatabaseUnreachablePenalty(node_hostname="client_1")
|
||||
schema = GreenAdminDatabaseUnreachablePenalty.ConfigSchema(node_hostname="client_1", sticky=True)
|
||||
comp = GreenAdminDatabaseUnreachablePenalty(config=schema)
|
||||
|
||||
request = ["network", "node", "client_1", "application", "DatabaseClient", "execute"]
|
||||
response = game.simulation.apply_request(request)
|
||||
@@ -147,7 +150,9 @@ def test_action_penalty():
|
||||
"""Test that the action penalty is correctly applied when agent performs any action"""
|
||||
|
||||
# Create an ActionPenalty Reward
|
||||
Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
schema = ActionPenalty.ConfigSchema(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
# Penalty = ActionPenalty(action_penalty=-0.75, do_nothing_penalty=0.125)
|
||||
Penalty = ActionPenalty(schema)
|
||||
|
||||
# Assert that penalty is applied if action isn't DONOTHING
|
||||
reward_value = Penalty.calculate(
|
||||
|
||||
Reference in New Issue
Block a user