From ed01293b862cb28fc7d13b9d04e994d98ca663cb Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 1 Mar 2024 16:02:27 +0000 Subject: [PATCH] Make db admin reward persistent --- src/primaite/game/agent/rewards.py | 8 +++++--- src/primaite/game/game.py | 2 +- .../simulator/system/applications/database_client.py | 9 ++++----- .../simulator/system/applications/web_browser.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 4eb1ab3f..882ad024 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -263,11 +263,13 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): :type state: Dict """ db_state = access_from_nested_dict(state, self.location_in_state) - if db_state is NOT_PRESENT_IN_STATE or "connections_status" not in db_state: + if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") - connections_status = db_state["connections_status"] - if False in connections_status: + last_connection_successful = db_state["last_connection_successful"] + if last_connection_successful is False: return -1.0 + elif last_connection_successful is True: + return 1.0 return 0 @classmethod diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index b9f92d3a..cf21dd40 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -296,7 +296,7 @@ class PrimaiteGame: if service_type == "DatabaseService": if "options" in service_cfg: opt = service_cfg["options"] - new_service.password = opt.get("backup_server_ip", None) + new_service.password = opt.get("db_password", None) new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) if service_type == "FTPServer": if "options" in service_cfg: diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index fe8180d7..addad35a 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -1,5 +1,5 @@ from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from uuid import uuid4 from primaite import getLogger @@ -26,7 +26,7 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} - _connections_status: List[bool] = [] + _last_connection_successful: Optional[bool] = None """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): @@ -46,7 +46,7 @@ class DatabaseClient(Application): can_connect = self.connect(connection_id=list(self.connections.keys())[-1]) else: can_connect = self.connect() - self._connections_status.append(can_connect) + self._last_connection_successful = can_connect return can_connect def describe_state(self) -> Dict: @@ -57,8 +57,7 @@ class DatabaseClient(Application): """ state = super().describe_state() # list of connections that were established or verified during this step. - state["connections_status"] = [c for c in self._connections_status] - self._connections_status.clear() + state["last_connection_successful"] = self._last_connection_successful return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index 6f2c479c..9fa86328 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -199,7 +199,7 @@ class WebBrowser(Application): def state(self) -> Dict: """Return the contents of this dataclass as a dict for use with describe_state method.""" if self.status == self._HistoryItemStatus.LOADED: - outcome = self.response_code + outcome = self.response_code.value else: outcome = self.status.value return {"url": self.url, "outcome": outcome}