#2064: added a method that checks if the class can perform actions and added it where necessary + tests everywhere
This commit is contained in:
@@ -48,6 +48,10 @@ class DatabaseService(Service):
|
||||
|
||||
def backup_database(self) -> bool:
|
||||
"""Create a backup of the database to the configured backup server."""
|
||||
# check if this action can be performed
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
# check if the backup server was configured
|
||||
if self.backup_server is None:
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
@@ -73,6 +77,10 @@ class DatabaseService(Service):
|
||||
|
||||
def restore_backup(self) -> bool:
|
||||
"""Restore a backup from backup server."""
|
||||
# check if this action can be performed
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
|
||||
@@ -51,13 +51,16 @@ class DNSClient(Service):
|
||||
"""
|
||||
pass
|
||||
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
|
||||
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> Union[bool, None]:
|
||||
"""
|
||||
Adds a domain name to the DNS Client cache.
|
||||
|
||||
:param: domain_name: The domain name to save to cache
|
||||
:param: ip_address: The IP Address to attach the domain name to
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
self.dns_cache[domain_name] = ip_address
|
||||
|
||||
def check_domain_exists(
|
||||
@@ -72,6 +75,9 @@ class DNSClient(Service):
|
||||
:param: session_id: The Session ID the payload is to originate from. Optional.
|
||||
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
# check if DNS server is configured
|
||||
if self.dns_server is None:
|
||||
self.sys_log.error(f"{self.name}: DNS Server is not configured")
|
||||
|
||||
@@ -48,6 +48,9 @@ class DNSServer(Service):
|
||||
:param target_domain: The single domain name requested by a DNS client.
|
||||
:return ip_address: The IP address of that domain name or None.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
return self.dns_table.get(target_domain)
|
||||
|
||||
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
|
||||
@@ -60,6 +63,9 @@ class DNSServer(Service):
|
||||
:param: domain_ip_address: The IP address that the domain should route to
|
||||
:type: domain_ip_address: IPv4Address
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return
|
||||
|
||||
self.dns_table[domain_name] = domain_ip_address
|
||||
|
||||
def reset_component_for_episode(self, episode: int):
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -41,6 +40,25 @@ class Service(IOSoftware):
|
||||
restart_countdown: Optional[int] = None
|
||||
"If currently restarting, how many timesteps remain until the restart is finished."
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the service can perform actions.
|
||||
|
||||
This is done by checking if the service is operating properly or the node it is installed
|
||||
in is operational.
|
||||
|
||||
Returns true if the software can perform actions.
|
||||
"""
|
||||
if not super()._can_perform_action():
|
||||
return False
|
||||
|
||||
if self.operating_state is not self.operating_state.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
Receives a payload from the SessionManager.
|
||||
@@ -108,8 +126,7 @@ class Service(IOSoftware):
|
||||
def start(self, **kwargs) -> None:
|
||||
"""Start the service."""
|
||||
# cant start the service if the node it is on is off
|
||||
if self.software_manager and self.software_manager.node.operating_state is not NodeOperatingState.ON:
|
||||
self.sys_log.error(f"Unable to start service. {self.software_manager.node.hostname} is not turned on.")
|
||||
if not super()._can_perform_action():
|
||||
return
|
||||
|
||||
if self.operating_state == ServiceOperatingState.STOPPED:
|
||||
|
||||
Reference in New Issue
Block a user