From 173f110fb248911c7e5748e585e67c0fbbc04b15 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 24 Jul 2024 16:38:06 +0100 Subject: [PATCH 01/12] #2769: initial commit of user account actions --- pyproject.toml | 2 +- src/primaite/game/agent/actions.py | 36 ++++++++++++++++++ .../simulator/network/hardware/base.py | 9 +++++ tests/conftest.py | 3 ++ .../test_remote_user_account_actions.py | 38 +++++++++++++++++++ 5 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py diff --git a/pyproject.toml b/pyproject.toml index 9e919604..f63ee4c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.20.0, < 3", + "ray[rllib] >= 2.32.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", "sb3-contrib==2.1.0", diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 9a5fedc9..bf8e4323 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1071,6 +1071,39 @@ class NodeNetworkServiceReconAction(AbstractAction): ] +class NodeAccountsChangePasswordAction(AbstractAction): + """Action which changes the password for a user.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + pass + + +class NodeSessionsRemoteLoginAction(AbstractAction): + """Action which performs a remote session login.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + pass + + def form_request(self, node_id: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", node_id, "remote_logon"] + + +class NodeSessionsRemoteLogoutAction(AbstractAction): + """Action which performs a remote session logout.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + pass + + def form_request(self, node_id: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", node_id, "remote_logoff"] + + class ActionManager: """Class which manages the action space for an agent.""" @@ -1122,6 +1155,9 @@ class ActionManager: "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, "CONFIGURE_DOSBOT": ConfigureDoSBotAction, + "NODE_ACCOUNTS_CHANGEPASSWORD": NodeAccountsChangePasswordAction, + "NODE_SESSIONS_REMOTE_LOGIN": NodeSessionsRemoteLoginAction, + "NODE_SESSIONS_REMOTE_LOGOUT": NodeSessionsRemoteLogoutAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 15c44821..831a8539 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1072,6 +1072,15 @@ class Node(SimComponent): "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request + rm.add_request( + "remote_logon", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement remote_logon request + rm.add_request( + "remote_logoff", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement remote_logoff request + self._os_request_manager = RequestManager() self._os_request_manager.add_request( "scan", diff --git a/tests/conftest.py b/tests/conftest.py index 54519e2b..e1ce41b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,6 +458,9 @@ def game_and_agent(): {"type": "HOST_NIC_DISABLE"}, {"type": "NETWORK_PORT_ENABLE"}, {"type": "NETWORK_PORT_DISABLE"}, + {"type": "NODE_ACCOUNTS_CHANGEPASSWORD"}, + {"type": "NODE_SESSIONS_REMOTE_LOGIN"}, + {"type": "NODE_SESSIONS_REMOTE_LOGOUT"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py new file mode 100644 index 00000000..807715bb --- /dev/null +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -0,0 +1,38 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + + +def test_remote_logon(game_and_agent): + """Test that the remote session login action works.""" + game, agent = game_and_agent + + action = ( + "NODE_SESSIONS_REMOTE_LOGIN", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that there is a logged in user + + +def test_remote_logoff(game_and_agent): + """Test that the remote session logout action works.""" + game, agent = game_and_agent + + action = ( + "NODE_SESSIONS_REMOTE_LOGIN", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that there is a logged in user + + action = ( + "NODE_SESSIONS_REMOTE_LOGOUT", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert the user has logged out From df50ec8abc65d08961cc49716925e6296f68b787 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 25 Jul 2024 10:02:32 +0100 Subject: [PATCH 02/12] #2769: add change password action --- src/primaite/game/agent/actions.py | 8 ++++---- src/primaite/simulator/network/hardware/base.py | 5 ++++- .../test_user_account_change_password.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index bf8e4323..266c667b 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1077,16 +1077,16 @@ class NodeAccountsChangePasswordAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self) -> RequestFormat: + def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - pass + return ["network", "node", node_id, "change_password"] class NodeSessionsRemoteLoginAction(AbstractAction): """Action which performs a remote session login.""" def __init__(self, manager: "ActionManager", **kwargs) -> None: - pass + super().__init__(manager=manager) def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -1097,7 +1097,7 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): """Action which performs a remote session logout.""" def __init__(self, manager: "ActionManager", **kwargs) -> None: - pass + super().__init__(manager=manager) def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 831a8539..3ef33ac3 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1071,7 +1071,10 @@ class Node(SimComponent): rm.add_request( "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request - + rm.add_request( + "change_password", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement change_password request rm.add_request( "remote_logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py new file mode 100644 index 00000000..27328100 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py @@ -0,0 +1,13 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +def test_remote_logon(game_and_agent): + """Test that the remote session login action works.""" + game, agent = game_and_agent + + action = ( + "NODE_ACCOUNTS_CHANGEPASSWORD", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that the user account password is changed From 7b523d9450a3e7fef252458d4fbe0fe9e0f4928c Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 30 Jul 2024 11:33:52 +0100 Subject: [PATCH 03/12] #2769: added changes which should align with 2735 once merged --- src/primaite/game/agent/actions.py | 12 +++++----- .../simulator/network/hardware/base.py | 12 ---------- .../test_remote_user_account_actions.py | 23 ++++++++++++++----- .../test_user_account_change_password.py | 14 +++++++++-- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 266c667b..19442818 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1077,9 +1077,9 @@ class NodeAccountsChangePasswordAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "change_password"] + return ["network", "node", node_id, "accounts", "change_password", username, current_password, new_password] class NodeSessionsRemoteLoginAction(AbstractAction): @@ -1088,9 +1088,9 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "remote_logon"] + return ["network", "node", node_id, "sessions", "remote_login", username, password] class NodeSessionsRemoteLogoutAction(AbstractAction): @@ -1099,9 +1099,9 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "remote_logoff"] + return ["network", "node", node_id, "sessions", "remote_logout", remote_session_id] class ActionManager: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 3ef33ac3..15c44821 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1071,18 +1071,6 @@ class Node(SimComponent): rm.add_request( "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request - rm.add_request( - "change_password", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement change_password request - rm.add_request( - "remote_logon", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement remote_logon request - rm.add_request( - "remote_logoff", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement remote_logoff request self._os_request_manager = RequestManager() self._os_request_manager.add_request( diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py index 807715bb..2e282d77 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -1,38 +1,49 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.host.computer import Computer def test_remote_logon(game_and_agent): """Test that the remote session login action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0}, + {"node_id": 0, "username": "test_user", "password": "password"}, ) agent.store_action(action) game.step() - # TODO Assert that there is a logged in user + assert len(client_1.user_session_manager.remote_sessions) == 1 def test_remote_logoff(game_and_agent): """Test that the remote session logout action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0}, + {"node_id": 0, "username": "test_user", "password": "password"}, ) agent.store_action(action) game.step() - # TODO Assert that there is a logged in user + assert len(client_1.user_session_manager.remote_sessions) == 1 + + remote_session_id = client_1.user_session_manager.remote_sessions[0].uuid action = ( "NODE_SESSIONS_REMOTE_LOGOUT", - {"node_id": 0}, + {"node_id": 0, "remote_session_id": remote_session_id}, ) agent.store_action(action) game.step() - # TODO Assert the user has logged out + assert len(client_1.user_session_manager.remote_sessions) == 0 diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py index 27328100..3e6f55f6 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py @@ -1,13 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + def test_remote_logon(game_and_agent): """Test that the remote session login action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + user = next((user for user in client_1.user_manager.users.values() if user.username == "test_user"), None) + + assert user.password == "password" + action = ( "NODE_ACCOUNTS_CHANGEPASSWORD", - {"node_id": 0}, + {"node_id": 0, "username": user.username, "current_password": user.password, "new_password": "test_pass"}, ) agent.store_action(action) game.step() - # TODO Assert that the user account password is changed + assert user.password == "test_pass" From b4893c44989ba498e31ca1e47fcd94685e0f9301 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 5 Aug 2024 16:27:53 +0100 Subject: [PATCH 04/12] #2769 - Add remote ip as action parameter --- src/primaite/game/agent/actions.py | 33 ++++++++++++++++--- .../test_remote_user_account_actions.py | 2 +- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 7c908f42..2ddeff3d 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1079,7 +1079,18 @@ class NodeAccountsChangePasswordAction(AbstractAction): def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "accounts", "change_password", username, current_password, new_password] + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserManager", + "change_password", + username, + current_password, + new_password, + ] class NodeSessionsRemoteLoginAction(AbstractAction): @@ -1088,9 +1099,21 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str, username: str, password: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "sessions", "remote_login", username, password] + # TODO: change this so it creates a remote connection using terminal rather than a local remote login + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserSessionManager", + "remote_login", + username, + password, + remote_ip, + ] class NodeSessionsRemoteLogoutAction(AbstractAction): @@ -1101,7 +1124,9 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "sessions", "remote_logout", remote_session_id] + # TODO: change this so it destroys a remote connection using terminal rather than a local remote login + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "UserSessionManager", "remote_logout", remote_session_id] class ActionManager: diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py index 2e282d77..25079226 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -12,7 +12,7 @@ def test_remote_logon(game_and_agent): action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password"}, + {"node_id": 0, "username": "test_user", "password": "password", "remote_ip": "10.0.2.2"}, ) agent.store_action(action) game.step() From 3df55a708d31f192a8a414673dce3e23e9126486 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 11 Aug 2024 23:24:29 +0100 Subject: [PATCH 05/12] #2769 - add actions and tests for terminal --- src/primaite/game/agent/actions.py | 28 ++- .../system/services/terminal/terminal.py | 120 +++++++------ tests/conftest.py | 7 +- .../actions/test_terminal_actions.py | 165 ++++++++++++++++++ .../test_remote_user_account_actions.py | 49 ------ .../test_user_account_change_password.py | 23 --- 6 files changed, 253 insertions(+), 139 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/test_terminal_actions.py delete mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py delete mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 2ddeff3d..f421cb0b 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1101,15 +1101,14 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # TODO: change this so it creates a remote connection using terminal rather than a local remote login node_name = self.manager.get_node_name_by_idx(node_id) return [ "network", "node", node_name, "service", - "UserSessionManager", - "remote_login", + "Terminal", + "ssh_to_remote", username, password, remote_ip, @@ -1122,11 +1121,21 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: + def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # TODO: change this so it destroys a remote connection using terminal rather than a local remote login node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "UserSessionManager", "remote_logout", remote_session_id] + return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] + + +class NodeSendRemoteCommandAction(AbstractAction): + """Action which sends a terminal command to a remote node via SSH.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "Terminal", "send_remote_command", remote_ip, command] class ActionManager: @@ -1180,9 +1189,10 @@ class ActionManager: "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, "CONFIGURE_DOSBOT": ConfigureDoSBotAction, - "NODE_ACCOUNTS_CHANGEPASSWORD": NodeAccountsChangePasswordAction, - "NODE_SESSIONS_REMOTE_LOGIN": NodeSessionsRemoteLoginAction, - "NODE_SESSIONS_REMOTE_LOGOUT": NodeSessionsRemoteLogoutAction, + "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, + "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, + "SSH_LOGOUT_LOGOUT": NodeSessionsRemoteLogoutAction, + "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 876b1694..ead5c66a 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -92,7 +92,7 @@ class LocalTerminalConnection(TerminalClientConnection): if not self.is_active: self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") return None - return self.parent_terminal.execute(command, connection_id=self.connection_uuid) + return self.parent_terminal.execute(command) class RemoteTerminalConnection(TerminalClientConnection): @@ -162,22 +162,36 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - rm.add_request( - "send", - request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - ) + # rm.add_request( + # "send", + # request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), + # ) - def _login(request: RequestFormat, context: Dict) -> RequestResponse: - login = self._process_local_login(username=request[0], password=request[1]) - if login: - return RequestResponse( - status="success", - data={ - "ip_address": login.ip_address, - }, - ) - else: - return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) + # def _login(request: RequestFormat, context: Dict) -> RequestResponse: + # login = self._process_local_login(username=request[0], password=request[1]) + # if login: + # return RequestResponse( + # status="success", + # data={ + # "ip_address": login.ip_address, + # }, + # ) + # else: + # return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) + # + # rm.add_request( + # "Login", + # request_type=RequestType(func=_login), + # ) + + # def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: + # """Logoff from connection.""" + # connection_uuid = request[0] + # self.parent.user_session_manager.local_logout(connection_uuid) + # self._disconnect(connection_uuid) + # return RequestResponse(status="success", data={}) + # + # rm.add_request("Logoff", request_type=RequestType(func=_logoff)) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) @@ -191,10 +205,34 @@ class Terminal(Service): else: return RequestResponse(status="failure", data={}) + rm.add_request( + "ssh_to_remote", + request_type=RequestType(func=_remote_login), + ) + + def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse: + """Logoff from remote connection.""" + ip_address = IPv4Address(request[0]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + if remote_connection: + outcome = self._disconnect(remote_connection.connection_uuid) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={"reason": "No remote connection held."}, + ) + + rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff)) + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" - command: str = request[0] - ip_address: IPv4Address = IPv4Address(request[1]) + ip_address: IPv4Address = IPv4Address(request[0]) + command: str = request[1] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -209,30 +247,11 @@ class Terminal(Service): data={}, ) - def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - """Logoff from connection.""" - connection_uuid = request[0] - self.parent.user_session_manager.local_logout(connection_uuid) - self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) - rm.add_request( - "Login", - request_type=RequestType(func=_login), - ) - - rm.add_request( - "Remote Login", - request_type=RequestType(func=_remote_login), - ) - - rm.add_request( - "Execute", + "send_remote_command", request_type=RequestType(func=remote_execute_request), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff)) - return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: @@ -280,13 +299,9 @@ class Terminal(Service): if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None - connection_request_id = str(uuid4()) - self._client_connection_requests[connection_request_id] = None if ip_address: # Assuming that if IP is passed we are connecting to remote - return self._send_remote_login( - username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id - ) + return self._send_remote_login(username=username, password=password, ip_address=ip_address) else: return self._process_local_login(username=username, password=password) @@ -320,32 +335,24 @@ class Terminal(Service): username: str, password: str, ip_address: IPv4Address, - connection_request_id: str, + connection_request_id: Optional[str] = None, is_reattempt: bool = False, ) -> Optional[RemoteTerminalConnection]: """Send a remote login attempt and connect to Node. :param: username: Username used to connect to the remote node. :type: username: str - :param: password: Password used to connect to the remote node :type: password: str - :param: ip_address: Target Node IP address for login attempt. :type: ip_address: IPv4Address - - :param: connection_request_id: Connection Request ID - :type: connection_request_id: str - + :param: connection_request_id: Connection Request ID, if not provided, a new one is generated + :type: connection_request_id: Optional[str] :param: is_reattempt: True if the request has been reattempted. Default False. :type: is_reattempt: Optional[bool] - :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. - """ - self.sys_log.info( - f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" - ) + connection_request_id = connection_request_id or str(uuid4()) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -360,6 +367,9 @@ class Terminal(Service): self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password) diff --git a/tests/conftest.py b/tests/conftest.py index d2f9bb2f..2ae6299d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,9 +458,10 @@ def game_and_agent(): {"type": "HOST_NIC_DISABLE"}, {"type": "NETWORK_PORT_ENABLE"}, {"type": "NETWORK_PORT_DISABLE"}, - {"type": "NODE_ACCOUNTS_CHANGEPASSWORD"}, - {"type": "NODE_SESSIONS_REMOTE_LOGIN"}, - {"type": "NODE_SESSIONS_REMOTE_LOGOUT"}, + {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, + {"type": "SSH_TO_REMOTE"}, + {"type": "SSH_LOGOUT_LOGOUT"}, + {"type": "NODE_SEND_REMOTE_COMMAND"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py new file mode 100644 index 00000000..ce0810eb --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -0,0 +1,165 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import UserManager +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if not connection_established: + pytest.fail("Remote SSH connection could not be established") + + +def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "wrong_password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "failure" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if connection_established: + pytest.fail("Remote SSH connection was established despite wrong password") + + +def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_um: UserManager = server_1.software_manager.software["UserManager"] + server_1_um.add_user("user123", "password", is_admin=True) + + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + assert server_1_um.users["user123"].password == "different_password" + + +def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + # Log in remotely + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + + # Change password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + # Assert that the user cannot execute an action + # TODO: should the db conn object get destroyed on both nodes? or is that not realistic? + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": server_1.network_interface[1].ip_address, + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert server_1.file_system.get_folder("folder123") is None + assert server_1.file_system.get_file("folder123", "doggo.pdf") is None diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py deleted file mode 100644 index 25079226..00000000 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ /dev/null @@ -1,49 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.host.computer import Computer - - -def test_remote_logon(game_and_agent): - """Test that the remote session login action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - - action = ( - "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password", "remote_ip": "10.0.2.2"}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 1 - - -def test_remote_logoff(game_and_agent): - """Test that the remote session logout action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - - action = ( - "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password"}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 1 - - remote_session_id = client_1.user_session_manager.remote_sessions[0].uuid - - action = ( - "NODE_SESSIONS_REMOTE_LOGOUT", - {"node_id": 0, "remote_session_id": remote_session_id}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 0 diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py deleted file mode 100644 index 3e6f55f6..00000000 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py +++ /dev/null @@ -1,23 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.host.computer import Computer - - -def test_remote_logon(game_and_agent): - """Test that the remote session login action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - user = next((user for user in client_1.user_manager.users.values() if user.username == "test_user"), None) - - assert user.password == "password" - - action = ( - "NODE_ACCOUNTS_CHANGEPASSWORD", - {"node_id": 0, "username": user.username, "current_password": user.password, "new_password": "test_pass"}, - ) - agent.store_action(action) - game.step() - - assert user.password == "test_pass" From 929bd46d6dea2e53d292a26ec765bdb06908d792 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 12 Aug 2024 14:16:04 +0100 Subject: [PATCH 06/12] #2769 - Make changing password disconnect remote sessions --- src/primaite/game/agent/actions.py | 12 +++++++++++- .../simulator/network/hardware/base.py | 18 ++++++++++++++++++ .../system/services/terminal/terminal.py | 7 ++++++- .../actions/test_terminal_actions.py | 8 +++++--- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index f421cb0b..d588c018 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1134,8 +1134,18 @@ class NodeSendRemoteCommandAction(AbstractAction): super().__init__(manager=manager) def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "Terminal", "send_remote_command", remote_ip, command] + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "send_remote_command", + remote_ip, + {"command": command}, + ] class ActionManager: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 1441c93b..68b45c2e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -990,6 +990,7 @@ class UserManager(Service): if user and user.password == current_password: user.password = new_password self.sys_log.info(f"{self.name}: Password changed for {username}") + self._user_session_manager._logout_user(user=user) return True self.sys_log.info(f"{self.name}: Password change failed for {username}") return False @@ -1027,6 +1028,10 @@ class UserManager(Service): self.sys_log.info(f"{self.name}: Failed to enable user: {username}") return False + @property + def _user_session_manager(self) -> "UserSessionManager": + return self.software_manager.software["UserSessionManager"] # noqa + class UserSession(SimComponent): """ @@ -1435,6 +1440,19 @@ class UserSessionManager(Service): """ return self._logout(local=False, remote_session_id=remote_session_id) + def _logout_user(self, user: Union[str, User]) -> bool: + """End a user session by username or user object.""" + if isinstance(user, str): + user = self._user_manager.users[user] # grab user object from username + for sess_id, session in self.remote_sessions.items(): + if session.user is user: + self._logout(local=False, remote_session_id=sess_id) + return True + if self.local_user_logged_in and self.local_session.user is user: + self.local_logout() + return True + return False + @property def local_user_logged_in(self) -> bool: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index ead5c66a..79dc698f 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor +# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -232,7 +234,7 @@ class Terminal(Service): def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" ip_address: IPv4Address = IPv4Address(request[0]) - command: str = request[1] + command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -328,6 +330,9 @@ class Terminal(Service): def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" + if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id): + self._disconnect(connection_id) + return False return connection_id in self._connections def _send_remote_login( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index ce0810eb..84d21bb0 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -8,6 +8,8 @@ from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection @@ -17,8 +19,8 @@ def game_and_agent_fixture(game_and_agent): """Create a game with a simple agent that can be controlled by the tests.""" game, agent = game_and_agent - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - client_1.start_up_duration = 3 + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) return (game, agent) @@ -154,7 +156,7 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam "NODE_SEND_REMOTE_COMMAND", { "node_id": 0, - "remote_ip": server_1.network_interface[1].ip_address, + "remote_ip": str(server_1.network_interface[1].ip_address), "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], }, ) From 1d2705eb1b95bb369926bc5858feef7a56cf4abe Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 15 Aug 2024 20:16:11 +0100 Subject: [PATCH 07/12] #2769 - Add user login observations --- .../observations/firewall_observation.py | 20 +- .../agent/observations/host_observations.py | 21 + .../agent/observations/node_observations.py | 8 + .../agent/observations/router_observation.py | 17 +- .../simulator/network/hardware/base.py | 3 +- tests/assets/configs/data_manipulation.yaml | 942 ++++++++++++++++++ .../observations/test_user_observations.py | 89 ++ 7 files changed, 1096 insertions(+), 4 deletions(-) create mode 100644 tests/assets/configs/data_manipulation.yaml create mode 100644 tests/integration_tests/game_layer/observations/test_user_observations.py diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 4f1a9d90..42ceaff0 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -10,6 +10,7 @@ from primaite import getLogger from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list: List[int], protocol_list: List[str], num_rules: int, + include_users: bool, ) -> None: """ Initialise a firewall observation instance. @@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. :type num_rules: int + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] @@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. :rtype: ObsType """ + firewall_state = access_from_nested_dict(state, self.where) + if firewall_state is NOT_PRESENT_IN_STATE: + return self.default_observation obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "ACL": { @@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 7053d019..4419ccc7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + include_users: bool, ) -> None: """ Initialise a host observation instance. @@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.include_num_access = include_num_access + self.include_users = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services @@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod @@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index c68531f8..e263cadb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Flag to include the number of accesses.""" file_system_requires_scan: bool = True """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + include_users: Optional[bool] = True + """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.include_users is None: + host_config.include_users = config.include_users for router_config in config.routers: if router_config.num_ports is None: @@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.protocol_list = config.protocol_list if router_config.num_rules is None: router_config.num_rules = config.num_rules + if router_config.include_users is None: + router_config.include_users = config.include_users for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.protocol_list = config.protocol_list if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules + if firewall_config.include_users is None: + firewall_config.include_users = config.include_users hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index f1d4ec8e..d064936a 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports: List[PortObservation], num_ports: int, acl: ACLObservation, + include_users: bool, ) -> None: """ Initialise a router observation instance. @@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :type num_ports: int :param acl: ACL observation representing the access control list of the router. :type acl: ACLObservation + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl self.num_ports: int = num_ports - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: @@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): obs["ACL"] = self.acl.observe(state) if self.ports: obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 68b45c2e..b0c48e7d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1265,7 +1265,8 @@ class UserSessionManager(Service): :return: A dictionary representing the current state. """ state = super().describe_state() - state["active_remote_logins"] = len(self.remote_sessions) + state["current_local_user"] = None if not self.local_session else self.local_session.user.username + state["active_remote_sessions"] = list(self.remote_sessions.keys()) return state @property diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml new file mode 100644 index 00000000..97442903 --- /dev/null +++ b/tests/assets/configs/data_manipulation.yaml @@ -0,0 +1,942 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py new file mode 100644 index 00000000..ca5e2543 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -0,0 +1,89 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +@pytest.fixture +def env_with_ssh() -> PrimaiteGymEnv: + """Build data manipulation environment with SSH port open on router.""" + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.agent.flatten_obs = False + router: Router = env.game.simulation.network.get_node_by_hostname("router_1") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + return env + + +def extract_login_numbers_from_obs(obs): + """Traverse the observation dictionary and return number of user sessions for all nodes.""" + login_nums = {} + for node_name, node_obs in obs["NODES"].items(): + login_nums[node_name] = node_obs.get("users") + return login_nums + + +class TestUserObservations: + """Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins.""" + + def test_no_sessions_at_episode_start(self, env_with_ssh): + """Test that all of the login observations start at 0 before any logins occur.""" + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_single_login(self, env_with_ssh: PrimaiteGymEnv): + """Test that performing a remote login increases the remote_sessions observation by 1.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + assert db_srv_logins_obs["local_login"] == 0 + assert db_srv_logins_obs["remote_sessions"] == 1 + for o in logins_obs.values(): # the remaining obs after popping HOST2 + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_logout(self, env_with_ssh: PrimaiteGymEnv): + """Test that remote_sessions observation correctly decreases upon logout.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv): + """Log in from 5 remote places and check that only a max of 3 is shown in the observation.""" + MAX_OBSERVABLE_SESSIONS = 3 + # Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_session_manager.remote_session_timeout_steps = 20 + db_srv.user_session_manager.max_remote_sessions = 5 + node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller") + + for i, node_name in enumerate(node_names): + node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name) + node.terminal._send_remote_login("admin", "admin", "192.168.1.14") + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1) + assert len(db_srv.user_session_manager.remote_sessions) == i + 1 From 21c0b02ff79ed4c469baa2874be543cfb8e94fc0 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 16 Aug 2024 09:21:27 +0100 Subject: [PATCH 08/12] #2769 - update observation tests with new parameter --- .../game_layer/observations/test_firewall_observation.py | 1 + .../game_layer/observations/test_node_observations.py | 1 + .../game_layer/observations/test_router_observation.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 99417e33..34a37f5e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -33,6 +33,7 @@ def test_firewall_observation(): wildcard_list=["0.0.0.255", "0.0.0.1"], port_list=["HTTP", "DNS"], protocol_list=["TCP"], + include_users=False, ) observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 1edb0442..69d9f106 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,7 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + include_users=False, ) assert host_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c534307f..48d29cfb 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -27,7 +27,7 @@ def test_router_observation(): port_list=["HTTP", "DNS"], protocol_list=["TCP"], ) - router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) From d74227e34f663799ef28ccb71062ebf690eb76ca Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 16 Aug 2024 10:10:26 +0100 Subject: [PATCH 09/12] #2769 - update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c63b114..8ac61df4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. +- Added actions to establish SSH connections, send commands remotely and terminate SSH connections. +- Added actions to change users' passwords. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. +- Node observations can now be configured to show the number of active local and remote logins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) From aeca5fb6a27efb77dd588e78f6815d056c5db8da Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 10:28:39 +0100 Subject: [PATCH 10/12] #2769 - Clean up incorrect names and commented out code [skip ci] --- src/primaite/game/agent/actions.py | 2 +- .../system/services/terminal/terminal.py | 30 ------------------- tests/conftest.py | 2 +- 3 files changed, 2 insertions(+), 32 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index d588c018..2a0c5351 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1201,7 +1201,7 @@ class ActionManager: "CONFIGURE_DOSBOT": ConfigureDoSBotAction, "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, - "SSH_LOGOUT_LOGOUT": NodeSessionsRemoteLogoutAction, + "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 79dc698f..406facd1 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -164,36 +164,6 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - # rm.add_request( - # "send", - # request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - # ) - - # def _login(request: RequestFormat, context: Dict) -> RequestResponse: - # login = self._process_local_login(username=request[0], password=request[1]) - # if login: - # return RequestResponse( - # status="success", - # data={ - # "ip_address": login.ip_address, - # }, - # ) - # else: - # return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) - # - # rm.add_request( - # "Login", - # request_type=RequestType(func=_login), - # ) - - # def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - # """Logoff from connection.""" - # connection_uuid = request[0] - # self.parent.user_session_manager.local_logout(connection_uuid) - # self._disconnect(connection_uuid) - # return RequestResponse(status="success", data={}) - # - # rm.add_request("Logoff", request_type=RequestType(func=_logoff)) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) diff --git a/tests/conftest.py b/tests/conftest.py index 2ae6299d..abc851c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -460,7 +460,7 @@ def game_and_agent(): {"type": "NETWORK_PORT_DISABLE"}, {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, {"type": "SSH_TO_REMOTE"}, - {"type": "SSH_LOGOUT_LOGOUT"}, + {"type": "SESSIONS_REMOTE_LOGOFF"}, {"type": "NODE_SEND_REMOTE_COMMAND"}, ] From a997cebbc69ce0196e7843ec1349057275991f55 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 11:14:53 +0000 Subject: [PATCH 11/12] Apply suggestions from code review [skip ci] --- .../game_layer/actions/test_terminal_actions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index 84d21bb0..d011c1e8 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -151,7 +151,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam game.step() # Assert that the user cannot execute an action - # TODO: should the db conn object get destroyed on both nodes? or is that not realistic? action = ( "NODE_SEND_REMOTE_COMMAND", { From 2c71958c913d4186474ea2896c30bb23f56c888c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 19 Aug 2024 12:55:45 +0100 Subject: [PATCH 12/12] #2748: Port of PrimAITE Internal changes. --- CHANGELOG.md | 1 + src/primaite/game/agent/interface.py | 2 + src/primaite/game/agent/rewards.py | 61 +++++++++++++++---- .../game_layer/test_rewards.py | 21 +++---- 4 files changed, 63 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c63b114..7daf1f60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. +- Added reward calculation details to AgentHistoryItem. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index f57dc191..14b97821 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel): reward: Optional[float] = None + reward_info: Dict[str, Any] = {} + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..b913501d 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,7 +47,15 @@ class AbstractReward: @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -67,7 +75,15 @@ class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. - :param state: The current state of the simulation. + :param state: Current simulation state :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: @@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last successful request returned was able to connect to the database server, because there has been an unsuccessful request since. + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: self._last_request_failed = last_action_response.response.status != "success" @@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: - _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"} return 0.0 last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + last_action_response.reward_info = {"last_connection_successful": last_connection_successful} + return 1.0 if last_connection_successful else -1.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -346,7 +369,15 @@ class SharedReward(AbstractReward): """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Simply access the other agent's reward and return it.""" + """Simply access the other agent's reward and return it. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return self.callback(self.agent_name) @classmethod @@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward): self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the penalty to be applied.""" + """Calculate the penalty to be applied. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ if last_action_response.action == "DONOTHING": return self.do_nothing_penalty else: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..e945f482 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), + ahi = AgentHistoryItem( + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={}, + request=["execute"], + response=response, ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 + assert ahi.reward_info == {"last_connection_successful": True} router.acl.remove_rule(position=2) @@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), - ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == -1.0 + assert ahi.reward_info == {"last_connection_successful": False} def test_shared_reward():