diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index e165c9ad..4cb31d25 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -17,6 +17,7 @@ from gymnasium import spaces from pydantic import BaseModel, Field, field_validator, ValidationInfo from primaite import getLogger +from primaite.interface.request import RequestFormat _LOGGER = getLogger(__name__) @@ -245,6 +246,27 @@ class NodeApplicationInstallAction(AbstractAction): ] +class ConfigureDatabaseClientAction(AbstractAction): + """Action which sets config parameters for a database client on a node.""" + + class _Opts(BaseModel): + """Schema for options that can be passed to this action.""" + + server_ip_address: Optional[str] = None + server_password: Optional[str] = None + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, options: Dict) -> RequestFormat: + """Return the action formatted as a request that can be ingested by the simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + if node_name is None: + return ["do_nothing"] + ConfigureDatabaseClientAction._Opts.model_validate(options) # check that options adhere to schema + return ["network", "node", node_name, "application", "DatabaseClient", "configure", options] + + class NodeApplicationRemoveAction(AbstractAction): """Action which removes/uninstalls an application.""" @@ -1045,6 +1067,7 @@ class ActionManager: "NODE_NMAP_PING_SCAN": NodeNMAPPingScanAction, "NODE_NMAP_PORT_SCAN": NodeNMAPPortScanAction, "NODE_NMAP_NETWORK_SERVICE_RECON": NodeNetworkServiceReconAction, + "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index bae2139b..6396c678 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -8,13 +8,14 @@ from uuid import uuid4 from prettytable import MARKDOWN, PrettyTable from pydantic import BaseModel -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.software_manager import SoftwareManager +from primaite.utils.validators import IPV4Address class DatabaseClientConnection(BaseModel): @@ -96,6 +97,14 @@ class DatabaseClient(Application): """ rm = super()._init_request_manager() rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute()))) + + def _configure(request: RequestFormat, context: Dict) -> RequestResponse: + ip, pw = request[-1].get("server_ip_address"), request[-1].get("server_password") + ip = None if ip is None else IPV4Address(ip) + success = self.configure(server_ip_address=ip, server_password=pw) + return RequestResponse.from_bool(success) + + rm.add_request("configure", RequestType(func=lambda request, context: _configure(request, context))) return rm def execute(self) -> bool: @@ -141,16 +150,17 @@ class DatabaseClient(Application): table.add_row([connection_id, connection.is_active]) print(table.get_string(sortby="Connection ID")) - def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): + def configure(self, server_ip_address: Optional[IPv4Address] = None, server_password: Optional[str] = None) -> bool: """ Configure the DatabaseClient to communicate with a DatabaseService. :param server_ip_address: The IP address of the Node the DatabaseService is on. :param server_password: The password on the DatabaseService. """ - self.server_ip_address = server_ip_address - self.server_password = server_password + self.server_ip_address = server_ip_address or self.server_ip_address + self.server_password = server_password or self.server_password self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.") + return True def connect(self) -> bool: """Connect the native client connection.""" diff --git a/tests/integration_tests/game_layer/observations/actions/__init__.py b/tests/integration_tests/game_layer/observations/actions/__init__.py new file mode 100644 index 00000000..be6c00e7 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/actions/__init__.py @@ -0,0 +1 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK diff --git a/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py b/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py new file mode 100644 index 00000000..17e262d1 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/actions/test_configure_actions.py @@ -0,0 +1,85 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from ipaddress import IPv4Address + +from primaite.game.agent.actions import ConfigureDatabaseClientAction +from primaite.simulator.system.applications.database_client import DatabaseClient +from tests.conftest import ControlledAgent + + +class TestConfigureDatabaseAction: + def test_configure_ip_password(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_ip_address": "192.168.1.99", + "server_password": "admin123", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == IPv4Address("192.168.1.99") + assert db_client.server_password == "admin123" + + def test_configure_ip(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_ip_address": "192.168.1.99", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == IPv4Address("192.168.1.99") + assert db_client.server_password is None + + def test_configure_password(self, game_and_agent): + game, agent = game_and_agent + agent: ControlledAgent + agent.action_manager.actions["CONFIGURE_DATABASE_CLIENT"] = ConfigureDatabaseClientAction(agent.action_manager) + + # make sure there is a database client on this node + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software["DatabaseClient"] + old_ip = db_client.server_ip_address + + action = ( + "CONFIGURE_DATABASE_CLIENT", + { + "node_id": 0, + "options": { + "server_password": "admin123", + }, + }, + ) + agent.store_action(action) + game.step() + + assert db_client.server_ip_address == old_ip + assert db_client.server_password is "admin123"