#2913: Fix remaining pydantic errors.
This commit is contained in:
@@ -28,7 +28,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Callable, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Never
|
||||
@@ -118,6 +118,12 @@ class DummyReward(AbstractReward, identifier="DummyReward"):
|
||||
class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
"""Reward function component which rewards the agent for maintaining the integrity of a database file."""
|
||||
|
||||
node_hostname: str
|
||||
folder_name: str
|
||||
file_name: str
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for DatabaseFileIntegrity."""
|
||||
|
||||
@@ -125,27 +131,6 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
folder_name: str
|
||||
file_name: str
|
||||
|
||||
def __init__(self, node_hostname: str, folder_name: str, file_name: str) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the database file.
|
||||
:type node_hostname: str
|
||||
:param folder_name: folder which contains the database file.
|
||||
:type folder_name: str
|
||||
:param file_name: name of the database file.
|
||||
:type file_name: str
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
folder_name,
|
||||
"files",
|
||||
file_name,
|
||||
]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
@@ -156,6 +141,17 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
"file_system",
|
||||
"folders",
|
||||
self.folder_name,
|
||||
"files",
|
||||
self.file_name,
|
||||
]
|
||||
|
||||
database_file_state = access_from_nested_dict(state, self.location_in_state)
|
||||
if database_file_state is NOT_PRESENT_IN_STATE:
|
||||
_LOGGER.debug(
|
||||
@@ -195,6 +191,12 @@ class DatabaseFileIntegrity(AbstractReward, identifier="DatabaseFileIntegrity"):
|
||||
class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
"""Reward function component which penalises the agent when the web server returns a 404 error."""
|
||||
|
||||
node_hostname: str
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
location_in_state: List[str] = [""]
|
||||
reward: float = 0.0
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
"""ConfigSchema for WebServer404Penalty."""
|
||||
|
||||
@@ -202,22 +204,6 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
service_name: str
|
||||
sticky: bool = True
|
||||
|
||||
def __init__(self, node_hostname: str, service_name: str, sticky: bool = True) -> None:
|
||||
"""Initialise the reward component.
|
||||
|
||||
:param node_hostname: Hostname of the node which contains the web server service.
|
||||
:type node_hostname: str
|
||||
:param service_name: Name of the web server service.
|
||||
:type service_name: str
|
||||
:param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
the reward if there were any responses this timestep.
|
||||
:type sticky: bool
|
||||
"""
|
||||
self.sticky: bool = sticky
|
||||
self.reward: float = 0.0
|
||||
"""Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the reward for the current state.
|
||||
|
||||
@@ -228,6 +214,13 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
"services",
|
||||
self.service_name,
|
||||
]
|
||||
web_service_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
# if webserver is no longer installed on the node, return 0
|
||||
@@ -274,6 +267,7 @@ class WebServer404Penalty(AbstractReward, identifier="WebServer404Penalty"):
|
||||
|
||||
class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePenalty"):
|
||||
"""Penalises the agent when the web browser fails to fetch a webpage."""
|
||||
|
||||
node_hostname: str = ""
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
@@ -287,22 +281,6 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
sticky: bool = True
|
||||
reward: float = 0.0
|
||||
|
||||
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
# """
|
||||
# Initialise the reward component.
|
||||
|
||||
# :param node_hostname: Hostname of the node which has the web browser.
|
||||
# :type node_hostname: str
|
||||
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
# the reward if there were any responses this timestep.
|
||||
# :type sticky: bool
|
||||
# """
|
||||
# self._node: str = node_hostname
|
||||
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
|
||||
# self.sticky: bool = sticky
|
||||
# self.reward: float = 0.0
|
||||
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
@@ -317,7 +295,13 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
:return: Reward value
|
||||
:rtype: float
|
||||
"""
|
||||
self.location_in_state: List[str] = ["network", "nodes", self.node_hostname, "applications", "WebBrowser"]
|
||||
self.location_in_state = [
|
||||
"network",
|
||||
"nodes",
|
||||
self.node_hostname,
|
||||
"applications",
|
||||
"WebBrowser",
|
||||
]
|
||||
web_browser_state = access_from_nested_dict(state, self.location_in_state)
|
||||
|
||||
if web_browser_state is NOT_PRESENT_IN_STATE:
|
||||
@@ -371,6 +355,7 @@ class WebpageUnavailablePenalty(AbstractReward, identifier="WebpageUnavailablePe
|
||||
|
||||
class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdminDatabaseUnreachablePenalty"):
|
||||
"""Penalises the agent when the green db clients fail to connect to the database."""
|
||||
|
||||
node_hostname: str = ""
|
||||
_node: str = node_hostname
|
||||
sticky: bool = True
|
||||
@@ -382,22 +367,6 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
node_hostname: str
|
||||
sticky: bool = True
|
||||
|
||||
# def __init__(self, node_hostname: str, sticky: bool = True) -> None:
|
||||
# """
|
||||
# Initialise the reward component.
|
||||
|
||||
# :param node_hostname: Hostname of the node where the database client sits.
|
||||
# :type node_hostname: str
|
||||
# :param sticky: If True, calculate the reward based on the most recent response status. If False, only calculate
|
||||
# the reward if there were any responses this timestep.
|
||||
# :type sticky: bool
|
||||
# """
|
||||
# self._node: str = node_hostname
|
||||
# self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
|
||||
# self.sticky: bool = sticky
|
||||
# self.reward: float = 0.0
|
||||
# """Reward value calculated last time any responses were seen. Used for persisting sticky rewards."""
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""
|
||||
Calculate the reward based on current simulation state, and the recent agent action.
|
||||
@@ -449,6 +418,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward, identifier="GreenAdmi
|
||||
|
||||
class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
"""Adds another agent's reward to the overall reward."""
|
||||
|
||||
agent_name: str
|
||||
|
||||
class ConfigSchema(AbstractReward.ConfigSchema):
|
||||
@@ -456,19 +426,6 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
|
||||
agent_name: str
|
||||
|
||||
# def __init__(self, agent_name: Optional[str] = None) -> None:
|
||||
# """
|
||||
# Initialise the shared reward.
|
||||
|
||||
# The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work
|
||||
# correctly.
|
||||
|
||||
# :param agent_name: The name whose reward is an input
|
||||
# :type agent_name: Optional[str]
|
||||
# """
|
||||
# # self.agent_name = agent_name
|
||||
# """Agent whose reward to track."""
|
||||
|
||||
def default_callback(agent_name: str) -> Never:
|
||||
"""
|
||||
Default callback to prevent calling this reward until it's properly initialised.
|
||||
@@ -508,6 +465,7 @@ class SharedReward(AbstractReward, identifier="SharedReward"):
|
||||
|
||||
class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
"""Apply a negative reward when taking any action except DONOTHING."""
|
||||
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
@@ -517,21 +475,6 @@ class ActionPenalty(AbstractReward, identifier="ActionPenalty"):
|
||||
action_penalty: float = -1.0
|
||||
do_nothing_penalty: float = 0.0
|
||||
|
||||
# def __init__(self, action_penalty: float, do_nothing_penalty: float) -> None:
|
||||
# """
|
||||
# Initialise the reward.
|
||||
|
||||
# Reward or penalise agents for doing nothing or taking actions.
|
||||
|
||||
# :param action_penalty: Reward to give agents for taking any action except DONOTHING
|
||||
# :type action_penalty: float
|
||||
# :param do_nothing_penalty: Reward to give agent for taking the DONOTHING action
|
||||
# :type do_nothing_penalty: float
|
||||
# """
|
||||
# super().__init__(action_penalty=action_penalty, do_nothing_penalty=do_nothing_penalty)
|
||||
# self.action_penalty = action_penalty
|
||||
# self.do_nothing_penalty = do_nothing_penalty
|
||||
|
||||
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
|
||||
"""Calculate the penalty to be applied.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.interface.request import RequestResponse
|
||||
|
||||
class TestWebServer404PenaltySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=False)
|
||||
reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=False)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
@@ -38,7 +38,7 @@ class TestWebServer404PenaltySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = WebServer404Penalty("computer", "WebService", sticky=True)
|
||||
reward = WebServer404Penalty(node_hostname="computer", service_name="WebService", sticky=True)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
codes = []
|
||||
@@ -67,7 +67,7 @@ class TestWebServer404PenaltySticky:
|
||||
|
||||
class TestWebpageUnavailabilitySticky:
|
||||
def test_non_sticky(self):
|
||||
reward = WebpageUnavailablePenalty("computer", sticky=False)
|
||||
reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=False)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
@@ -127,7 +127,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = WebpageUnavailablePenalty("computer", sticky=True)
|
||||
reward = WebpageUnavailablePenalty(node_hostname="computer", sticky=True)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
@@ -188,7 +188,7 @@ class TestWebpageUnavailabilitySticky:
|
||||
|
||||
class TestGreenAdminDatabaseUnreachableSticky:
|
||||
def test_non_sticky(self):
|
||||
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=False)
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=False)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
@@ -244,7 +244,7 @@ class TestGreenAdminDatabaseUnreachableSticky:
|
||||
assert reward.calculate(state, last_action_response) == -1.0
|
||||
|
||||
def test_sticky(self):
|
||||
reward = GreenAdminDatabaseUnreachablePenalty("computer", sticky=True)
|
||||
reward = GreenAdminDatabaseUnreachablePenalty(node_hostname="computer", sticky=True)
|
||||
|
||||
# no response codes yet, reward is 0
|
||||
action, params, request = "DO_NOTHING", {}, ["DONOTHING"]
|
||||
|
||||
Reference in New Issue
Block a user