From 173f110fb248911c7e5748e585e67c0fbbc04b15 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 24 Jul 2024 16:38:06 +0100 Subject: [PATCH 01/18] #2769: initial commit of user account actions --- pyproject.toml | 2 +- src/primaite/game/agent/actions.py | 36 ++++++++++++++++++ .../simulator/network/hardware/base.py | 9 +++++ tests/conftest.py | 3 ++ .../test_remote_user_account_actions.py | 38 +++++++++++++++++++ 5 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py diff --git a/pyproject.toml b/pyproject.toml index 9e919604..f63ee4c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ license-files = ["LICENSE"] [project.optional-dependencies] rl = [ - "ray[rllib] >= 2.20.0, < 3", + "ray[rllib] >= 2.32.0, < 3", "tensorflow==2.12.0", "stable-baselines3[extra]==2.1.0", "sb3-contrib==2.1.0", diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 9a5fedc9..bf8e4323 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1071,6 +1071,39 @@ class NodeNetworkServiceReconAction(AbstractAction): ] +class NodeAccountsChangePasswordAction(AbstractAction): + """Action which changes the password for a user.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + pass + + +class NodeSessionsRemoteLoginAction(AbstractAction): + """Action which performs a remote session login.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + pass + + def form_request(self, node_id: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", node_id, "remote_logon"] + + +class NodeSessionsRemoteLogoutAction(AbstractAction): + """Action which performs a remote session logout.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + pass + + def form_request(self, node_id: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return ["network", "node", node_id, "remote_logoff"] + + class ActionManager: """Class which manages the action space for an agent.""" @@ -1122,6 +1155,9 @@ class ActionManager: "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, "CONFIGURE_DOSBOT": ConfigureDoSBotAction, + "NODE_ACCOUNTS_CHANGEPASSWORD": NodeAccountsChangePasswordAction, + "NODE_SESSIONS_REMOTE_LOGIN": NodeSessionsRemoteLoginAction, + "NODE_SESSIONS_REMOTE_LOGOUT": NodeSessionsRemoteLogoutAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 15c44821..831a8539 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1072,6 +1072,15 @@ class Node(SimComponent): "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request + rm.add_request( + "remote_logon", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement remote_logon request + rm.add_request( + "remote_logoff", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement remote_logoff request + self._os_request_manager = RequestManager() self._os_request_manager.add_request( "scan", diff --git a/tests/conftest.py b/tests/conftest.py index 54519e2b..e1ce41b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,6 +458,9 @@ def game_and_agent(): {"type": "HOST_NIC_DISABLE"}, {"type": "NETWORK_PORT_ENABLE"}, {"type": "NETWORK_PORT_DISABLE"}, + {"type": "NODE_ACCOUNTS_CHANGEPASSWORD"}, + {"type": "NODE_SESSIONS_REMOTE_LOGIN"}, + {"type": "NODE_SESSIONS_REMOTE_LOGOUT"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py new file mode 100644 index 00000000..807715bb --- /dev/null +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -0,0 +1,38 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK + + +def test_remote_logon(game_and_agent): + """Test that the remote session login action works.""" + game, agent = game_and_agent + + action = ( + "NODE_SESSIONS_REMOTE_LOGIN", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that there is a logged in user + + +def test_remote_logoff(game_and_agent): + """Test that the remote session logout action works.""" + game, agent = game_and_agent + + action = ( + "NODE_SESSIONS_REMOTE_LOGIN", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that there is a logged in user + + action = ( + "NODE_SESSIONS_REMOTE_LOGOUT", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert the user has logged out From df50ec8abc65d08961cc49716925e6296f68b787 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 25 Jul 2024 10:02:32 +0100 Subject: [PATCH 02/18] #2769: add change password action --- src/primaite/game/agent/actions.py | 8 ++++---- src/primaite/simulator/network/hardware/base.py | 5 ++++- .../test_user_account_change_password.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index bf8e4323..266c667b 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1077,16 +1077,16 @@ class NodeAccountsChangePasswordAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self) -> RequestFormat: + def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - pass + return ["network", "node", node_id, "change_password"] class NodeSessionsRemoteLoginAction(AbstractAction): """Action which performs a remote session login.""" def __init__(self, manager: "ActionManager", **kwargs) -> None: - pass + super().__init__(manager=manager) def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" @@ -1097,7 +1097,7 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): """Action which performs a remote session logout.""" def __init__(self, manager: "ActionManager", **kwargs) -> None: - pass + super().__init__(manager=manager) def form_request(self, node_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 831a8539..3ef33ac3 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1071,7 +1071,10 @@ class Node(SimComponent): rm.add_request( "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request - + rm.add_request( + "change_password", + RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), + ) # TODO implement change_password request rm.add_request( "remote_logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py new file mode 100644 index 00000000..27328100 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py @@ -0,0 +1,13 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +def test_remote_logon(game_and_agent): + """Test that the remote session login action works.""" + game, agent = game_and_agent + + action = ( + "NODE_ACCOUNTS_CHANGEPASSWORD", + {"node_id": 0}, + ) + agent.store_action(action) + game.step() + + # TODO Assert that the user account password is changed From 7b523d9450a3e7fef252458d4fbe0fe9e0f4928c Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 30 Jul 2024 11:33:52 +0100 Subject: [PATCH 03/18] #2769: added changes which should align with 2735 once merged --- src/primaite/game/agent/actions.py | 12 +++++----- .../simulator/network/hardware/base.py | 12 ---------- .../test_remote_user_account_actions.py | 23 ++++++++++++++----- .../test_user_account_change_password.py | 14 +++++++++-- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 266c667b..19442818 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1077,9 +1077,9 @@ class NodeAccountsChangePasswordAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "change_password"] + return ["network", "node", node_id, "accounts", "change_password", username, current_password, new_password] class NodeSessionsRemoteLoginAction(AbstractAction): @@ -1088,9 +1088,9 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "remote_logon"] + return ["network", "node", node_id, "sessions", "remote_login", username, password] class NodeSessionsRemoteLogoutAction(AbstractAction): @@ -1099,9 +1099,9 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str) -> RequestFormat: + def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "remote_logoff"] + return ["network", "node", node_id, "sessions", "remote_logout", remote_session_id] class ActionManager: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 3ef33ac3..15c44821 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1071,18 +1071,6 @@ class Node(SimComponent): rm.add_request( "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) ) # TODO implement logoff request - rm.add_request( - "change_password", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement change_password request - rm.add_request( - "remote_logon", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement remote_logon request - rm.add_request( - "remote_logoff", - RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on), - ) # TODO implement remote_logoff request self._os_request_manager = RequestManager() self._os_request_manager.add_request( diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py index 807715bb..2e282d77 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -1,38 +1,49 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.host.computer import Computer def test_remote_logon(game_and_agent): """Test that the remote session login action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0}, + {"node_id": 0, "username": "test_user", "password": "password"}, ) agent.store_action(action) game.step() - # TODO Assert that there is a logged in user + assert len(client_1.user_session_manager.remote_sessions) == 1 def test_remote_logoff(game_and_agent): """Test that the remote session logout action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0}, + {"node_id": 0, "username": "test_user", "password": "password"}, ) agent.store_action(action) game.step() - # TODO Assert that there is a logged in user + assert len(client_1.user_session_manager.remote_sessions) == 1 + + remote_session_id = client_1.user_session_manager.remote_sessions[0].uuid action = ( "NODE_SESSIONS_REMOTE_LOGOUT", - {"node_id": 0}, + {"node_id": 0, "remote_session_id": remote_session_id}, ) agent.store_action(action) game.step() - # TODO Assert the user has logged out + assert len(client_1.user_session_manager.remote_sessions) == 0 diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py index 27328100..3e6f55f6 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py @@ -1,13 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + def test_remote_logon(game_and_agent): """Test that the remote session login action works.""" game, agent = game_and_agent + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) + user = next((user for user in client_1.user_manager.users.values() if user.username == "test_user"), None) + + assert user.password == "password" + action = ( "NODE_ACCOUNTS_CHANGEPASSWORD", - {"node_id": 0}, + {"node_id": 0, "username": user.username, "current_password": user.password, "new_password": "test_pass"}, ) agent.store_action(action) game.step() - # TODO Assert that the user account password is changed + assert user.password == "test_pass" From b4893c44989ba498e31ca1e47fcd94685e0f9301 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 5 Aug 2024 16:27:53 +0100 Subject: [PATCH 04/18] #2769 - Add remote ip as action parameter --- src/primaite/game/agent/actions.py | 33 ++++++++++++++++--- .../test_remote_user_account_actions.py | 2 +- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 7c908f42..2ddeff3d 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1079,7 +1079,18 @@ class NodeAccountsChangePasswordAction(AbstractAction): def form_request(self, node_id: str, username: str, current_password: str, new_password: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "accounts", "change_password", username, current_password, new_password] + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserManager", + "change_password", + username, + current_password, + new_password, + ] class NodeSessionsRemoteLoginAction(AbstractAction): @@ -1088,9 +1099,21 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str, username: str, password: str) -> RequestFormat: + def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "sessions", "remote_login", username, password] + # TODO: change this so it creates a remote connection using terminal rather than a local remote login + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserSessionManager", + "remote_login", + username, + password, + remote_ip, + ] class NodeSessionsRemoteLogoutAction(AbstractAction): @@ -1101,7 +1124,9 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", node_id, "sessions", "remote_logout", remote_session_id] + # TODO: change this so it destroys a remote connection using terminal rather than a local remote login + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "UserSessionManager", "remote_logout", remote_session_id] class ActionManager: diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py index 2e282d77..25079226 100644 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py @@ -12,7 +12,7 @@ def test_remote_logon(game_and_agent): action = ( "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password"}, + {"node_id": 0, "username": "test_user", "password": "password", "remote_ip": "10.0.2.2"}, ) agent.store_action(action) game.step() From 368e846c8b59488746e56610727fd3d99bc54090 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 10:07:19 +0100 Subject: [PATCH 05/18] 2772 - Generate pdf benchmark from --- benchmark/primaite_benchmark.py | 15 ++++++++++- benchmark/report.py | 47 +++++++++++++++++++++++++-------- benchmark/static/styles.css | 34 ++++++++++++++++++++++++ pyproject.toml | 3 ++- 4 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 benchmark/static/styles.css diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 0e6c2acc..2b09870d 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -5,7 +5,7 @@ from datetime import datetime from pathlib import Path from typing import Any, Dict, Final, Tuple -from report import build_benchmark_md_report +from report import build_benchmark_md_report, md2pdf from stable_baselines3 import PPO import primaite @@ -159,6 +159,13 @@ def run( learning_rate: float = 3e-4, ) -> None: """Run the PrimAITE benchmark.""" + # generate report folder + v_str = f"v{primaite.__version__}" + + version_result_dir = _RESULTS_ROOT / v_str + version_result_dir.mkdir(exist_ok=True, parents=True) + output_path = version_result_dir / f"PrimAITE {v_str} Benchmark Report.md" + benchmark_start_time = datetime.now() session_metadata_dict = {} @@ -193,6 +200,12 @@ def run( session_metadata=session_metadata_dict, config_path=data_manipulation_config_path(), results_root_path=_RESULTS_ROOT, + output_path=output_path, + ) + md2pdf( + md_path=output_path, + pdf_path=str(output_path).replace(".md", ".pdf"), + css_path="benchmark/static/styles.css", ) diff --git a/benchmark/report.py b/benchmark/report.py index e1ff46b9..408e91cf 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -2,6 +2,7 @@ import json import sys from datetime import datetime +from os import PathLike from pathlib import Path from typing import Dict, Optional @@ -14,7 +15,7 @@ from utils import _get_system_info import primaite PLOT_CONFIG = { - "size": {"auto_size": False, "width": 1500, "height": 900}, + "size": {"auto_size": False, "width": 800, "height": 800}, "template": "plotly_white", "range_slider": False, } @@ -144,6 +145,20 @@ def _plot_benchmark_metadata( yaxis={"title": "Total Reward"}, title=title, ) + fig.update_layout( + legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=0.01, + bgcolor="rgba(255,255,255,0.3)", + ) + ) + for trace in fig["data"]: + if trace["name"].startswith("Session"): + trace["showlegend"] = False + fig["data"][0]["name"] = "Individual Sessions" + fig["data"][0]["showlegend"] = True return fig @@ -194,6 +209,7 @@ def _plot_all_benchmarks_combined_session_av(results_directory: Path) -> Figure: title=title, ) fig["data"][0]["showlegend"] = True + fig.update_layout(legend=dict(yanchor="top", y=-0.2, xanchor="left", x=0.01, orientation="h")) return fig @@ -248,14 +264,7 @@ def _plot_av_s_per_100_steps_10_nodes( versions = sorted(list(version_times_dict.keys())) times = [version_times_dict[version] for version in versions] - fig.add_trace( - go.Bar( - x=versions, - y=times, - text=times, - textposition="auto", - ) - ) + fig.add_trace(go.Bar(x=versions, y=times, text=times, textposition="auto", texttemplate="%{y:.3f}")) fig.update_layout( xaxis_title="PrimAITE Version", @@ -267,7 +276,11 @@ def _plot_av_s_per_100_steps_10_nodes( def build_benchmark_md_report( - benchmark_start_time: datetime, session_metadata: Dict, config_path: Path, results_root_path: Path + benchmark_start_time: datetime, + session_metadata: Dict, + config_path: Path, + results_root_path: Path, + output_path: PathLike, ) -> None: """ Generates a Markdown report for a benchmarking session, documenting performance metrics and graphs. @@ -319,7 +332,7 @@ def build_benchmark_md_report( data = benchmark_metadata_dict primaite_version = data["primaite_version"] - with open(version_result_dir / f"PrimAITE v{primaite_version} Benchmark Report.md", "w") as file: + with open(output_path, "w") as file: # Title file.write(f"# PrimAITE v{primaite_version} Learning Benchmark\n") file.write("## PrimAITE Dev Team\n") @@ -393,3 +406,15 @@ def build_benchmark_md_report( f"![Performance of Minor and Bugfix Releases for Major Version {major_v}]" f"({performance_benchmark_plot_path.name})\n" ) + + +def md2pdf(md_path: PathLike, pdf_path: PathLike, css_path: PathLike) -> None: + """Generate PDF version of Markdown report.""" + from md2pdf.core import md2pdf + + md2pdf( + pdf_file_path=pdf_path, + md_file_path=md_path, + base_url=Path(md_path).parent, + css_file_path=css_path, + ) diff --git a/benchmark/static/styles.css b/benchmark/static/styles.css new file mode 100644 index 00000000..4fbb9bd5 --- /dev/null +++ b/benchmark/static/styles.css @@ -0,0 +1,34 @@ +body { + font-family: 'Arial', sans-serif; + line-height: 1.6; + /* margin: 1cm; */ +} +h1, h2, h3, h4, h5, h6 { + font-weight: bold; + /* margin: 1em 0; */ +} +p { + /* margin: 0.5em 0; */ +} +ul, ol { + margin: 1em 0; + padding-left: 1.5em; +} +pre { + background: #f4f4f4; + padding: 0.5em; + overflow-x: auto; +} +img { + max-width: 100%; + height: auto; +} +table { + width: 100%; + border-collapse: collapse; + margin: 1em 0; +} +th, td { + padding: 0.5em; + border: 1px solid #ddd; +} diff --git a/pyproject.toml b/pyproject.toml index c9b7c062..354df8b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,8 @@ dev = [ "wheel==0.38.4", "nbsphinx==0.9.4", "nbmake==1.5.4", - "pytest-xdist==3.3.1" + "pytest-xdist==3.3.1", + "md2pdf", ] [project.scripts] From fe599f77452bc8e48d8437c77b29e2e649f0dde7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 12:09:44 +0100 Subject: [PATCH 06/18] #2799 - Fix folder scan not being required and make it configurable --- CHANGELOG.md | 4 + .../observations/file_system_observations.py | 56 +++++++- .../agent/observations/host_observations.py | 18 ++- .../agent/observations/node_observations.py | 4 + .../_game/_agent/test_observations.py | 132 ++++++++++++++++++ 5 files changed, 206 insertions(+), 8 deletions(-) create mode 100644 tests/unit_tests/_primaite/_game/_agent/test_observations.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d999607..73a3f496 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality. +- File and folder observations can now be configured to always show the true health status, or require scanning like before. + +### Fixed +- Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) ## [3.2.0] - 2024-07-18 diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index cb48fe7d..bd130673 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -23,8 +23,10 @@ class FileObservation(AbstractObservation, identifier="FILE"): """Name of the file, used for querying simulation state dictionary.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to the file in the observation.""" + file_system_requires_scan: Optional[bool] = None + """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool) -> None: + def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: """ Initialise a file observation instance. @@ -34,9 +36,13 @@ class FileObservation(AbstractObservation, identifier="FILE"): :type where: WhereType :param include_num_access: Whether to include the number of accesses to the file in the observation. :type include_num_access: bool + :param file_system_requires_scan: If True, the file must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ self.where: WhereType = where self.include_num_access: bool = include_num_access + self.file_system_requires_scan: bool = file_system_requires_scan self.default_observation: ObsType = {"health_status": 0} if self.include_num_access: @@ -74,7 +80,11 @@ class FileObservation(AbstractObservation, identifier="FILE"): file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation - obs = {"health_status": file_state["visible_status"]} + if self.file_system_requires_scan: + health_status = file_state["visible_status"] + else: + health_status = file_state["health_status"] + obs = {"health_status": health_status} if self.include_num_access: obs["num_access"] = self._categorise_num_access(file_state["num_access"]) return obs @@ -104,8 +114,15 @@ class FileObservation(AbstractObservation, identifier="FILE"): :type parent_where: WhereType, optional :return: Constructed file observation instance. :rtype: FileObservation + :param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ - return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) + return cls( + where=parent_where + ["files", config.file_name], + include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, + ) class FolderObservation(AbstractObservation, identifier="FOLDER"): @@ -122,9 +139,16 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): """Number of spaces for file observations in this folder.""" include_num_access: Optional[bool] = None """Whether files in this folder should include the number of accesses in their observation.""" + file_system_requires_scan: Optional[bool] = None + """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" def __init__( - self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool + self, + where: WhereType, + files: Iterable[FileObservation], + num_files: int, + include_num_access: bool, + file_system_requires_scan: bool, ) -> None: """ Initialise a folder observation instance. @@ -141,9 +165,17 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): """ self.where: WhereType = where + self.file_system_requires_scan: bool = file_system_requires_scan + self.files: List[FileObservation] = files while len(self.files) < num_files: - self.files.append(FileObservation(where=None, include_num_access=include_num_access)) + self.files.append( + FileObservation( + where=None, + include_num_access=include_num_access, + file_system_requires_scan=self.file_system_requires_scan, + ) + ) while len(self.files) > num_files: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" @@ -168,7 +200,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): if folder_state is NOT_PRESENT_IN_STATE: return self.default_observation - health_status = folder_state["health_status"] + if self.file_system_requires_scan: + health_status = folder_state["visible_status"] + else: + health_status = folder_state["health_status"] obs = {} @@ -209,6 +244,13 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): # pass down shared/common config items for file_config in config.files: file_config.include_num_access = config.include_num_access + file_config.file_system_requires_scan = config.file_system_requires_scan files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] - return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) + return cls( + where=where, + files=files, + num_files=config.num_files, + include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, + ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index f9fd9b1a..7053d019 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -48,6 +48,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Whether to include the number of accesses to files observations on this host.""" + file_system_requires_scan: Optional[bool] = None + """ + If True, files and folders must be scanned to update the health state. If False, true state is always shown. + """ def __init__( self, @@ -64,6 +68,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_nmne: bool, monitored_traffic: Optional[Dict], include_num_access: bool, + file_system_requires_scan: bool, ) -> None: """ Initialise a host observation instance. @@ -95,6 +100,9 @@ class HostObservation(AbstractObservation, identifier="HOST"): :type monitored_traffic: Dict :param include_num_access: Flag to include the number of accesses to files. :type include_num_access: bool + :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. + If False, the true state is always shown. + :type file_system_requires_scan: bool """ self.where: WhereType = where @@ -120,7 +128,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.folders: List[FolderObservation] = folders while len(self.folders) < num_folders: self.folders.append( - FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) + FolderObservation( + where=None, + files=[], + num_files=num_files, + include_num_access=include_num_access, + file_system_requires_scan=file_system_requires_scan, + ) ) while len(self.folders) > num_folders: truncated_folder = self.folders.pop() @@ -226,6 +240,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): for folder_config in config.folders: folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files + folder_config.file_system_requires_scan = config.file_system_requires_scan for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne @@ -257,4 +272,5 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, + file_system_requires_scan=config.file_system_requires_scan, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index f7bfcc99..c68531f8 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -44,6 +44,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """A dict containing which traffic types are to be included in the observation.""" include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" + file_system_requires_scan: bool = True + """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -187,6 +189,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.monitored_traffic = config.monitored_traffic if host_config.include_num_access is None: host_config.include_num_access = config.include_num_access + if host_config.file_system_requires_scan is None: + host_config.file_system_requires_scan = config.file_system_requires_scan for router_config in config.routers: if router_config.num_ports is None: diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py new file mode 100644 index 00000000..7f590685 --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -0,0 +1,132 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import List + +import pytest +import yaml + +from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.game.agent.observations.host_observations import HostObservation + + +class TestFileSystemRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("file_system_requires_scan: true", True), + ("file_system_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set FileSystemRequiresScan to True.""" + obs_cfg_yaml = f""" + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + folders: List[FolderObservation] = host.folders + for j, folder in enumerate(folders): + assert folder.file_system_requires_scan == expected_val # Make sure folders require scan by default + files: List[FileObservation] = folder.files + for k, file in enumerate(files): + assert file.file_system_requires_scan == expected_val + + def test_file_require_scan(self): + file_state = {"health_status": 3, "visible_status": 1} + + obs_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=True) + assert obs_requiring_scan.observe(file_state)["health_status"] == 1 + + obs_not_requiring_scan = FileObservation([], include_num_access=False, file_system_requires_scan=False) + assert obs_not_requiring_scan.observe(file_state)["health_status"] == 3 + + def test_folder_require_scan(self): + folder_state = {"health_status": 3, "visible_status": 1} + + obs_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True + ) + assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + + obs_not_requiring_scan = FolderObservation( + [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False + ) + assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3 From b193b46b7b725137e150c7707e5dc1ff68a5bfd9 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 13:43:11 +0100 Subject: [PATCH 07/18] #2799 - Update observation tests --- .../game_layer/observations/test_file_system_observations.py | 2 ++ .../game_layer/observations/test_node_observations.py | 1 + tests/integration_tests/game_layer/test_observations.py | 1 + 3 files changed, 4 insertions(+) diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 1031dcb0..e2ab2990 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -26,6 +26,7 @@ def test_file_observation(simulation): dog_file_obs = FileObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, + file_system_requires_scan=True, ) assert dog_file_obs.space["health_status"] == spaces.Discrete(6) @@ -53,6 +54,7 @@ def test_folder_observation(simulation): root_folder_obs = FolderObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], include_num_access=False, + file_system_requires_scan=True, num_files=1, files=[], ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 8a36ea5c..1edb0442 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -38,6 +38,7 @@ def test_host_observation(simulation): applications=[], folders=[], network_interfaces=[], + file_system_requires_scan=True, ) assert host_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index ff83c532..d5679007 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -17,6 +17,7 @@ def test_file_observation(): dog_file_obs = FileObservation( where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], include_num_access=False, + file_system_requires_scan=True, ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) From d2693d974f48b9dad4cead272560783b7b420b94 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 7 Aug 2024 13:18:20 +0000 Subject: [PATCH 08/18] Fix relative path to primaite benchmark to align with build pipeline step --- benchmark/primaite_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/primaite_benchmark.py b/benchmark/primaite_benchmark.py index 2b09870d..86ed22a9 100644 --- a/benchmark/primaite_benchmark.py +++ b/benchmark/primaite_benchmark.py @@ -205,7 +205,7 @@ def run( md2pdf( md_path=output_path, pdf_path=str(output_path).replace(".md", ".pdf"), - css_path="benchmark/static/styles.css", + css_path="static/styles.css", ) From df9ab13209c49458d267f4ae01478e9eb9947585 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 9 Aug 2024 09:11:54 +0100 Subject: [PATCH 09/18] #2799 - Fix docstring --- .../game/agent/observations/file_system_observations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index bd130673..1c73d026 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -162,6 +162,9 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): :type num_files: int :param include_num_access: Whether to include the number of accesses to files in the observation. :type include_num_access: bool + :param file_system_requires_scan: If True, the folder must be scanned to update the health state. Tf False, + the true state is always shown. + :type file_system_requires_scan: bool """ self.where: WhereType = where From bf44ceaeac912195b683492d5d0843b9d74de16d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 9 Aug 2024 09:26:37 +0000 Subject: [PATCH 10/18] Apply suggestions from code review --- benchmark/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/report.py b/benchmark/report.py index 408e91cf..4035ceca 100644 --- a/benchmark/report.py +++ b/benchmark/report.py @@ -15,7 +15,7 @@ from utils import _get_system_info import primaite PLOT_CONFIG = { - "size": {"auto_size": False, "width": 800, "height": 800}, + "size": {"auto_size": False, "width": 800, "height": 640}, "template": "plotly_white", "range_slider": False, } From 3df55a708d31f192a8a414673dce3e23e9126486 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 11 Aug 2024 23:24:29 +0100 Subject: [PATCH 11/18] #2769 - add actions and tests for terminal --- src/primaite/game/agent/actions.py | 28 ++- .../system/services/terminal/terminal.py | 120 +++++++------ tests/conftest.py | 7 +- .../actions/test_terminal_actions.py | 165 ++++++++++++++++++ .../test_remote_user_account_actions.py | 49 ------ .../test_user_account_change_password.py | 23 --- 6 files changed, 253 insertions(+), 139 deletions(-) create mode 100644 tests/integration_tests/game_layer/actions/test_terminal_actions.py delete mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py delete mode 100644 tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 2ddeff3d..f421cb0b 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1101,15 +1101,14 @@ class NodeSessionsRemoteLoginAction(AbstractAction): def form_request(self, node_id: str, username: str, password: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # TODO: change this so it creates a remote connection using terminal rather than a local remote login node_name = self.manager.get_node_name_by_idx(node_id) return [ "network", "node", node_name, "service", - "UserSessionManager", - "remote_login", + "Terminal", + "ssh_to_remote", username, password, remote_ip, @@ -1122,11 +1121,21 @@ class NodeSessionsRemoteLogoutAction(AbstractAction): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: str, remote_session_id: str) -> RequestFormat: + def form_request(self, node_id: str, remote_ip: str) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # TODO: change this so it destroys a remote connection using terminal rather than a local remote login node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "UserSessionManager", "remote_logout", remote_session_id] + return ["network", "node", node_name, "service", "Terminal", "remote_logoff", remote_ip] + + +class NodeSendRemoteCommandAction(AbstractAction): + """Action which sends a terminal command to a remote node via SSH.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "Terminal", "send_remote_command", remote_ip, command] class ActionManager: @@ -1180,9 +1189,10 @@ class ActionManager: "CONFIGURE_DATABASE_CLIENT": ConfigureDatabaseClientAction, "CONFIGURE_RANSOMWARE_SCRIPT": ConfigureRansomwareScriptAction, "CONFIGURE_DOSBOT": ConfigureDoSBotAction, - "NODE_ACCOUNTS_CHANGEPASSWORD": NodeAccountsChangePasswordAction, - "NODE_SESSIONS_REMOTE_LOGIN": NodeSessionsRemoteLoginAction, - "NODE_SESSIONS_REMOTE_LOGOUT": NodeSessionsRemoteLogoutAction, + "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, + "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, + "SSH_LOGOUT_LOGOUT": NodeSessionsRemoteLogoutAction, + "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 876b1694..ead5c66a 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -92,7 +92,7 @@ class LocalTerminalConnection(TerminalClientConnection): if not self.is_active: self.parent_terminal.sys_log.warning("Connection inactive, cannot execute") return None - return self.parent_terminal.execute(command, connection_id=self.connection_uuid) + return self.parent_terminal.execute(command) class RemoteTerminalConnection(TerminalClientConnection): @@ -162,22 +162,36 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - rm.add_request( - "send", - request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - ) + # rm.add_request( + # "send", + # request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), + # ) - def _login(request: RequestFormat, context: Dict) -> RequestResponse: - login = self._process_local_login(username=request[0], password=request[1]) - if login: - return RequestResponse( - status="success", - data={ - "ip_address": login.ip_address, - }, - ) - else: - return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) + # def _login(request: RequestFormat, context: Dict) -> RequestResponse: + # login = self._process_local_login(username=request[0], password=request[1]) + # if login: + # return RequestResponse( + # status="success", + # data={ + # "ip_address": login.ip_address, + # }, + # ) + # else: + # return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) + # + # rm.add_request( + # "Login", + # request_type=RequestType(func=_login), + # ) + + # def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: + # """Logoff from connection.""" + # connection_uuid = request[0] + # self.parent.user_session_manager.local_logout(connection_uuid) + # self._disconnect(connection_uuid) + # return RequestResponse(status="success", data={}) + # + # rm.add_request("Logoff", request_type=RequestType(func=_logoff)) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) @@ -191,10 +205,34 @@ class Terminal(Service): else: return RequestResponse(status="failure", data={}) + rm.add_request( + "ssh_to_remote", + request_type=RequestType(func=_remote_login), + ) + + def _remote_logoff(request: RequestFormat, context: Dict) -> RequestResponse: + """Logoff from remote connection.""" + ip_address = IPv4Address(request[0]) + remote_connection = self._get_connection_from_ip(ip_address=ip_address) + if remote_connection: + outcome = self._disconnect(remote_connection.connection_uuid) + if outcome: + return RequestResponse( + status="success", + data={}, + ) + else: + return RequestResponse( + status="failure", + data={"reason": "No remote connection held."}, + ) + + rm.add_request("remote_logoff", request_type=RequestType(func=_remote_logoff)) + def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" - command: str = request[0] - ip_address: IPv4Address = IPv4Address(request[1]) + ip_address: IPv4Address = IPv4Address(request[0]) + command: str = request[1] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -209,30 +247,11 @@ class Terminal(Service): data={}, ) - def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - """Logoff from connection.""" - connection_uuid = request[0] - self.parent.user_session_manager.local_logout(connection_uuid) - self._disconnect(connection_uuid) - return RequestResponse(status="success", data={}) - rm.add_request( - "Login", - request_type=RequestType(func=_login), - ) - - rm.add_request( - "Remote Login", - request_type=RequestType(func=_remote_login), - ) - - rm.add_request( - "Execute", + "send_remote_command", request_type=RequestType(func=remote_execute_request), ) - rm.add_request("Logoff", request_type=RequestType(func=_logoff)) - return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: @@ -280,13 +299,9 @@ class Terminal(Service): if self.operating_state != ServiceOperatingState.RUNNING: self.sys_log.warning(f"{self.name}: Cannot login as service is not running.") return None - connection_request_id = str(uuid4()) - self._client_connection_requests[connection_request_id] = None if ip_address: # Assuming that if IP is passed we are connecting to remote - return self._send_remote_login( - username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id - ) + return self._send_remote_login(username=username, password=password, ip_address=ip_address) else: return self._process_local_login(username=username, password=password) @@ -320,32 +335,24 @@ class Terminal(Service): username: str, password: str, ip_address: IPv4Address, - connection_request_id: str, + connection_request_id: Optional[str] = None, is_reattempt: bool = False, ) -> Optional[RemoteTerminalConnection]: """Send a remote login attempt and connect to Node. :param: username: Username used to connect to the remote node. :type: username: str - :param: password: Password used to connect to the remote node :type: password: str - :param: ip_address: Target Node IP address for login attempt. :type: ip_address: IPv4Address - - :param: connection_request_id: Connection Request ID - :type: connection_request_id: str - + :param: connection_request_id: Connection Request ID, if not provided, a new one is generated + :type: connection_request_id: Optional[str] :param: is_reattempt: True if the request has been reattempted. Default False. :type: is_reattempt: Optional[bool] - :return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False. - """ - self.sys_log.info( - f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" - ) + connection_request_id = connection_request_id or str(uuid4()) if is_reattempt: valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id) if valid_connection_request: @@ -360,6 +367,9 @@ class Terminal(Service): self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.") return None + self.sys_log.info( + f"{self.name}: Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}" + ) transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password) diff --git a/tests/conftest.py b/tests/conftest.py index d2f9bb2f..2ae6299d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -458,9 +458,10 @@ def game_and_agent(): {"type": "HOST_NIC_DISABLE"}, {"type": "NETWORK_PORT_ENABLE"}, {"type": "NETWORK_PORT_DISABLE"}, - {"type": "NODE_ACCOUNTS_CHANGEPASSWORD"}, - {"type": "NODE_SESSIONS_REMOTE_LOGIN"}, - {"type": "NODE_SESSIONS_REMOTE_LOGOUT"}, + {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, + {"type": "SSH_TO_REMOTE"}, + {"type": "SSH_LOGOUT_LOGOUT"}, + {"type": "NODE_SEND_REMOTE_COMMAND"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py new file mode 100644 index 00000000..ce0810eb --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -0,0 +1,165 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +import pytest + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.base import UserManager +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.system.services.service import ServiceOperatingState +from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_remote_login(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if not connection_established: + pytest.fail("Remote SSH connection could not be established") + + +def test_remote_login_wrong_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "wrong_password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "failure" + + connection_established = False + for conn_str, conn_obj in client_1.terminal.connections.items(): + conn_obj: RemoteTerminalConnection + if conn_obj.ip_address == server_1.network_interface[1].ip_address: + connection_established = True + if connection_established: + pytest.fail("Remote SSH connection was established despite wrong password") + + +def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_um: UserManager = server_1.software_manager.software["UserManager"] + server_1_um.add_user("user123", "password", is_admin=True) + + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + assert server_1_um.users["user123"].password == "different_password" + + +def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + # create a new user account on server_1 that will be logged into remotely + server_1_usm: UserManager = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + # Log in remotely + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + + # Change password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 1, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + # Assert that the user cannot execute an action + # TODO: should the db conn object get destroyed on both nodes? or is that not realistic? + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": server_1.network_interface[1].ip_address, + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert server_1.file_system.get_folder("folder123") is None + assert server_1.file_system.get_file("folder123", "doggo.pdf") is None diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py deleted file mode 100644 index 25079226..00000000 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_remote_user_account_actions.py +++ /dev/null @@ -1,49 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.host.computer import Computer - - -def test_remote_logon(game_and_agent): - """Test that the remote session login action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - - action = ( - "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password", "remote_ip": "10.0.2.2"}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 1 - - -def test_remote_logoff(game_and_agent): - """Test that the remote session logout action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - - action = ( - "NODE_SESSIONS_REMOTE_LOGIN", - {"node_id": 0, "username": "test_user", "password": "password"}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 1 - - remote_session_id = client_1.user_session_manager.remote_sessions[0].uuid - - action = ( - "NODE_SESSIONS_REMOTE_LOGOUT", - {"node_id": 0, "remote_session_id": remote_session_id}, - ) - agent.store_action(action) - game.step() - - assert len(client_1.user_session_manager.remote_sessions) == 0 diff --git a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py b/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py deleted file mode 100644 index 3e6f55f6..00000000 --- a/tests/integration_tests/game_layer/actions/user_account_actions/test_user_account_change_password.py +++ /dev/null @@ -1,23 +0,0 @@ -# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.host.computer import Computer - - -def test_remote_logon(game_and_agent): - """Test that the remote session login action works.""" - game, agent = game_and_agent - - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - - client_1.user_manager.add_user(username="test_user", password="password", bypass_can_perform_action=True) - user = next((user for user in client_1.user_manager.users.values() if user.username == "test_user"), None) - - assert user.password == "password" - - action = ( - "NODE_ACCOUNTS_CHANGEPASSWORD", - {"node_id": 0, "username": user.username, "current_password": user.password, "new_password": "test_pass"}, - ) - agent.store_action(action) - game.step() - - assert user.password == "test_pass" From 929bd46d6dea2e53d292a26ec765bdb06908d792 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 12 Aug 2024 14:16:04 +0100 Subject: [PATCH 12/18] #2769 - Make changing password disconnect remote sessions --- src/primaite/game/agent/actions.py | 12 +++++++++++- .../simulator/network/hardware/base.py | 18 ++++++++++++++++++ .../system/services/terminal/terminal.py | 7 ++++++- .../actions/test_terminal_actions.py | 8 +++++--- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index f421cb0b..d588c018 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1134,8 +1134,18 @@ class NodeSendRemoteCommandAction(AbstractAction): super().__init__(manager=manager) def form_request(self, node_id: int, remote_ip: str, command: RequestFormat) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" node_name = self.manager.get_node_name_by_idx(node_id) - return ["network", "node", node_name, "service", "Terminal", "send_remote_command", remote_ip, command] + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "send_remote_command", + remote_ip, + {"command": command}, + ] class ActionManager: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 1441c93b..68b45c2e 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -990,6 +990,7 @@ class UserManager(Service): if user and user.password == current_password: user.password = new_password self.sys_log.info(f"{self.name}: Password changed for {username}") + self._user_session_manager._logout_user(user=user) return True self.sys_log.info(f"{self.name}: Password change failed for {username}") return False @@ -1027,6 +1028,10 @@ class UserManager(Service): self.sys_log.info(f"{self.name}: Failed to enable user: {username}") return False + @property + def _user_session_manager(self) -> "UserSessionManager": + return self.software_manager.software["UserSessionManager"] # noqa + class UserSession(SimComponent): """ @@ -1435,6 +1440,19 @@ class UserSessionManager(Service): """ return self._logout(local=False, remote_session_id=remote_session_id) + def _logout_user(self, user: Union[str, User]) -> bool: + """End a user session by username or user object.""" + if isinstance(user, str): + user = self._user_manager.users[user] # grab user object from username + for sess_id, session in self.remote_sessions.items(): + if session.user is user: + self._logout(local=False, remote_session_id=sess_id) + return True + if self.local_user_logged_in and self.local_session.user is user: + self.local_logout() + return True + return False + @property def local_user_logged_in(self) -> bool: """ diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index ead5c66a..79dc698f 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -23,6 +23,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager from primaite.simulator.system.services.service import Service, ServiceOperatingState +# TODO 2824: Since remote terminal connections and remote user sessions are the same thing, we could refactor +# the terminal to leverage the user session manager's list. This way we avoid potential bugs and code ducplication class TerminalClientConnection(BaseModel): """ TerminalClientConnection Class. @@ -232,7 +234,7 @@ class Terminal(Service): def remote_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: """Execute an instruction.""" ip_address: IPv4Address = IPv4Address(request[0]) - command: str = request[1] + command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: outcome = remote_connection.execute(command) @@ -328,6 +330,9 @@ class Terminal(Service): def _check_client_connection(self, connection_id: str) -> bool: """Check that client_connection_id is valid.""" + if not self.parent.user_session_manager.validate_remote_session_uuid(connection_id): + self._disconnect(connection_id) + return False return connection_id in self._connections def _send_remote_login( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index ce0810eb..84d21bb0 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -8,6 +8,8 @@ from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.base import UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.service import ServiceOperatingState from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection @@ -17,8 +19,8 @@ def game_and_agent_fixture(game_and_agent): """Create a game with a simple agent that can be controlled by the tests.""" game, agent = game_and_agent - client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") - client_1.start_up_duration = 3 + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) return (game, agent) @@ -154,7 +156,7 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam "NODE_SEND_REMOTE_COMMAND", { "node_id": 0, - "remote_ip": server_1.network_interface[1].ip_address, + "remote_ip": str(server_1.network_interface[1].ip_address), "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], }, ) From 1d2705eb1b95bb369926bc5858feef7a56cf4abe Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 15 Aug 2024 20:16:11 +0100 Subject: [PATCH 13/18] #2769 - Add user login observations --- .../observations/firewall_observation.py | 20 +- .../agent/observations/host_observations.py | 21 + .../agent/observations/node_observations.py | 8 + .../agent/observations/router_observation.py | 17 +- .../simulator/network/hardware/base.py | 3 +- tests/assets/configs/data_manipulation.yaml | 942 ++++++++++++++++++ .../observations/test_user_observations.py | 89 ++ 7 files changed, 1096 insertions(+), 4 deletions(-) create mode 100644 tests/assets/configs/data_manipulation.yaml create mode 100644 tests/integration_tests/game_layer/observations/test_user_observations.py diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 4f1a9d90..42ceaff0 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -10,6 +10,7 @@ from primaite import getLogger from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -32,6 +33,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -41,6 +44,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list: List[int], protocol_list: List[str], num_rules: int, + include_users: bool, ) -> None: """ Initialise a firewall observation instance. @@ -58,9 +62,13 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :type protocol_list: List[str] :param num_rules: Number of rules configured in the firewall. :type num_rules: int + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" self.ports: List[PortObservation] = [ PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] @@ -142,6 +150,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. :rtype: ObsType """ + firewall_state = access_from_nested_dict(state, self.where) + if firewall_state is NOT_PRESENT_IN_STATE: + return self.default_observation obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "ACL": { @@ -159,6 +170,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): }, }, } + if self.include_users: + sess = firewall_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -218,4 +235,5 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 7053d019..4419ccc7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -52,6 +52,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -69,6 +71,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + include_users: bool, ) -> None: """ Initialise a host observation instance. @@ -103,10 +106,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.include_num_access = include_num_access + self.include_users = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services @@ -165,6 +173,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 + if self.include_users: + self.default_observation["users"] = {"local_login": 0, "remote_sessions": 0} def observe(self, state: Dict) -> ObsType: """ @@ -192,6 +202,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_users: + sess = node_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -216,6 +232,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) + if self.include_users: + shape["users"] = spaces.Dict( + {"local_login": spaces.Discrete(2), "remote_sessions": spaces.Discrete(self.max_users + 1)} + ) return spaces.Dict(shape) @classmethod @@ -273,4 +293,5 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index c68531f8..e263cadb 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -46,6 +46,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """Flag to include the number of accesses.""" file_system_requires_scan: bool = True """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + include_users: Optional[bool] = True + """If True, report user session information.""" num_ports: Optional[int] = None """Number of ports.""" ip_list: Optional[List[str]] = None @@ -191,6 +193,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.include_users is None: + host_config.include_users = config.include_users for router_config in config.routers: if router_config.num_ports is None: @@ -205,6 +209,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.protocol_list = config.protocol_list if router_config.num_rules is None: router_config.num_rules = config.num_rules + if router_config.include_users is None: + router_config.include_users = config.include_users for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -217,6 +223,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.protocol_list = config.protocol_list if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules + if firewall_config.include_users is None: + firewall_config.include_users = config.include_users hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index f1d4ec8e..d064936a 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -39,6 +39,8 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + include_users: Optional[bool] = True + """If True, report user session information.""" def __init__( self, @@ -46,6 +48,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports: List[PortObservation], num_ports: int, acl: ACLObservation, + include_users: bool, ) -> None: """ Initialise a router observation instance. @@ -59,12 +62,16 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :type num_ports: int :param acl: ACL observation representing the access control list of the router. :type acl: ACLObservation + :param include_users: If True, report user session information. + :type include_users: bool """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl self.num_ports: int = num_ports - + self.include_users: bool = include_users + self.max_users: int = 3 + """Maximum number of remote sessions observable, excess sessions are truncated.""" while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: @@ -95,6 +102,12 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): obs["ACL"] = self.acl.observe(state) if self.ports: obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + if self.include_users: + sess = router_state["services"]["UserSessionManager"] + obs["users"] = { + "local_login": 1 if sess["current_local_user"] else 0, + "remote_sessions": min(self.max_users, len(sess["active_remote_sessions"])), + } return obs @property @@ -143,4 +156,4 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl, include_users=config.include_users) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 68b45c2e..b0c48e7d 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1265,7 +1265,8 @@ class UserSessionManager(Service): :return: A dictionary representing the current state. """ state = super().describe_state() - state["active_remote_logins"] = len(self.remote_sessions) + state["current_local_user"] = None if not self.local_session else self.local_session.user.username + state["active_remote_sessions"] = list(self.remote_sessions.keys()) return state @property diff --git a/tests/assets/configs/data_manipulation.yaml b/tests/assets/configs/data_manipulation.yaml new file mode 100644 index 00000000..97442903 --- /dev/null +++ b/tests/assets/configs/data_manipulation.yaml @@ -0,0 +1,942 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + sys_log_level: WARNING + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_FIX + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: ROUTER_ACL_ADDRULE + - type: ROUTER_ACL_REMOVERULE + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_FIX" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_CHECKHASH" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 50: # old action num: 26 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 51: # old action num: 27 + action: "ROUTER_ACL_ADDRULE" + options: + target_router: router_1 + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + source_wildcard_id: 0 + dest_wildcard_id: 0 + 52: # old action num: 28 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 0 + 53: # old action num: 29 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 1 + 54: # old action num: 30 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 2 + 55: # old action num: 31 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 3 + 56: # old action num: 32 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 4 + 57: # old action num: 33 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 5 + 58: # old action num: 34 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 6 + 59: # old action num: 35 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 7 + 60: # old action num: 36 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 8 + 61: # old action num: 37 + action: "ROUTER_ACL_REMOVERULE" + options: + target_router: router_1 + position: 9 + 62: # old action num: 38 + action: "HOST_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "HOST_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "HOST_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "HOST_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "HOST_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "HOST_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "HOST_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "HOST_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "HOST_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "HOST_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "HOST_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "HOST_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "HOST_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "HOST_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + agent_settings: + flatten_obs: true + action_masking: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - hostname: switch_1 + type: switch + num_ports: 8 + + - hostname: switch_2 + type: switch + num_ports: 8 + + - hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: WebServer + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - type: FTPClient + + - hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - type: FTPServer + + - hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - type: DNSClient + + links: + - endpoint_a_hostname: router_1 + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 8 + - endpoint_a_hostname: router_1 + endpoint_a_port: 2 + endpoint_b_hostname: switch_2 + endpoint_b_port: 8 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: domain_controller + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: web_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 3 + endpoint_b_hostname: database_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 4 + endpoint_b_hostname: backup_server + endpoint_b_port: 1 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + - endpoint_a_hostname: switch_2 + endpoint_a_port: 7 + endpoint_b_hostname: security_suite + endpoint_b_port: 2 diff --git a/tests/integration_tests/game_layer/observations/test_user_observations.py b/tests/integration_tests/game_layer/observations/test_user_observations.py new file mode 100644 index 00000000..ca5e2543 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_user_observations.py @@ -0,0 +1,89 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +@pytest.fixture +def env_with_ssh() -> PrimaiteGymEnv: + """Build data manipulation environment with SSH port open on router.""" + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.agent.flatten_obs = False + router: Router = env.game.simulation.network.get_node_by_hostname("router_1") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=3) + return env + + +def extract_login_numbers_from_obs(obs): + """Traverse the observation dictionary and return number of user sessions for all nodes.""" + login_nums = {} + for node_name, node_obs in obs["NODES"].items(): + login_nums[node_name] = node_obs.get("users") + return login_nums + + +class TestUserObservations: + """Test that the RouterObservation, FirewallObservation, and HostObservation have the correct number of logins.""" + + def test_no_sessions_at_episode_start(self, env_with_ssh): + """Test that all of the login observations start at 0 before any logins occur.""" + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_single_login(self, env_with_ssh: PrimaiteGymEnv): + """Test that performing a remote login increases the remote_sessions observation by 1.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + assert db_srv_logins_obs["local_login"] == 0 + assert db_srv_logins_obs["remote_sessions"] == 1 + for o in logins_obs.values(): # the remaining obs after popping HOST2 + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_logout(self, env_with_ssh: PrimaiteGymEnv): + """Test that remote_sessions observation correctly decreases upon logout.""" + client_1 = env_with_ssh.game.simulation.network.get_node_by_hostname("client_1") + client_1.terminal._send_remote_login("admin", "admin", "192.168.1.14") # connect to database server via ssh + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_manager.change_user_password("admin", "admin", "different_pass") # changing password logs out user + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + for o in logins_obs.values(): + assert o["local_login"] == 0 + assert o["remote_sessions"] == 0 + + def test_max_observable_sessions(self, env_with_ssh: PrimaiteGymEnv): + """Log in from 5 remote places and check that only a max of 3 is shown in the observation.""" + MAX_OBSERVABLE_SESSIONS = 3 + # Right now this is hardcoded as 3 in HostObservation, FirewallObservation, and RouterObservation + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + db_srv = env_with_ssh.game.simulation.network.get_node_by_hostname("database_server") + db_srv.user_session_manager.remote_session_timeout_steps = 20 + db_srv.user_session_manager.max_remote_sessions = 5 + node_names = ("client_1", "client_2", "backup_server", "security_suite", "domain_controller") + + for i, node_name in enumerate(node_names): + node = env_with_ssh.game.simulation.network.get_node_by_hostname(node_name) + node.terminal._send_remote_login("admin", "admin", "192.168.1.14") + + obs, *_ = env_with_ssh.step(0) + logins_obs = extract_login_numbers_from_obs(obs) + db_srv_logins_obs = logins_obs.pop("HOST2") # this is the index of db server + + assert db_srv_logins_obs["remote_sessions"] == min(MAX_OBSERVABLE_SESSIONS, i + 1) + assert len(db_srv.user_session_manager.remote_sessions) == i + 1 From 21c0b02ff79ed4c469baa2874be543cfb8e94fc0 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 16 Aug 2024 09:21:27 +0100 Subject: [PATCH 14/18] #2769 - update observation tests with new parameter --- .../game_layer/observations/test_firewall_observation.py | 1 + .../game_layer/observations/test_node_observations.py | 1 + .../game_layer/observations/test_router_observation.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 99417e33..34a37f5e 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -33,6 +33,7 @@ def test_firewall_observation(): wildcard_list=["0.0.0.255", "0.0.0.1"], port_list=["HTTP", "DNS"], protocol_list=["TCP"], + include_users=False, ) observation = firewall_observation.observe(firewall.describe_state()) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 1edb0442..69d9f106 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,7 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + include_users=False, ) assert host_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index c534307f..48d29cfb 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -27,7 +27,7 @@ def test_router_observation(): port_list=["HTTP", "DNS"], protocol_list=["TCP"], ) - router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl, include_users=False) # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) From d74227e34f663799ef28ccb71062ebf690eb76ca Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 16 Aug 2024 10:10:26 +0100 Subject: [PATCH 15/18] #2769 - update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c63b114..8ac61df4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. +- Added actions to establish SSH connections, send commands remotely and terminate SSH connections. +- Added actions to change users' passwords. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. +- Node observations can now be configured to show the number of active local and remote logins. ### Fixed - Folder observations showing the true health state without scanning (the old behaviour can be reenabled via config) From aeca5fb6a27efb77dd588e78f6815d056c5db8da Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 10:28:39 +0100 Subject: [PATCH 16/18] #2769 - Clean up incorrect names and commented out code [skip ci] --- src/primaite/game/agent/actions.py | 2 +- .../system/services/terminal/terminal.py | 30 ------------------- tests/conftest.py | 2 +- 3 files changed, 2 insertions(+), 32 deletions(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index d588c018..2a0c5351 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1201,7 +1201,7 @@ class ActionManager: "CONFIGURE_DOSBOT": ConfigureDoSBotAction, "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, - "SSH_LOGOUT_LOGOUT": NodeSessionsRemoteLogoutAction, + "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 79dc698f..406facd1 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -164,36 +164,6 @@ class Terminal(Service): def _init_request_manager(self) -> RequestManager: """Initialise Request manager.""" rm = super()._init_request_manager() - # rm.add_request( - # "send", - # request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())), - # ) - - # def _login(request: RequestFormat, context: Dict) -> RequestResponse: - # login = self._process_local_login(username=request[0], password=request[1]) - # if login: - # return RequestResponse( - # status="success", - # data={ - # "ip_address": login.ip_address, - # }, - # ) - # else: - # return RequestResponse(status="failure", data={"reason": "Invalid login credentials"}) - # - # rm.add_request( - # "Login", - # request_type=RequestType(func=_login), - # ) - - # def _logoff(request: RequestFormat, context: Dict) -> RequestResponse: - # """Logoff from connection.""" - # connection_uuid = request[0] - # self.parent.user_session_manager.local_logout(connection_uuid) - # self._disconnect(connection_uuid) - # return RequestResponse(status="success", data={}) - # - # rm.add_request("Logoff", request_type=RequestType(func=_logoff)) def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse: login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2]) diff --git a/tests/conftest.py b/tests/conftest.py index 2ae6299d..abc851c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -460,7 +460,7 @@ def game_and_agent(): {"type": "NETWORK_PORT_DISABLE"}, {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, {"type": "SSH_TO_REMOTE"}, - {"type": "SSH_LOGOUT_LOGOUT"}, + {"type": "SESSIONS_REMOTE_LOGOFF"}, {"type": "NODE_SEND_REMOTE_COMMAND"}, ] From a997cebbc69ce0196e7843ec1349057275991f55 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 19 Aug 2024 11:14:53 +0000 Subject: [PATCH 17/18] Apply suggestions from code review [skip ci] --- .../game_layer/actions/test_terminal_actions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index 84d21bb0..d011c1e8 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -151,7 +151,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam game.step() # Assert that the user cannot execute an action - # TODO: should the db conn object get destroyed on both nodes? or is that not realistic? action = ( "NODE_SEND_REMOTE_COMMAND", { From 2c71958c913d4186474ea2896c30bb23f56c888c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 19 Aug 2024 12:55:45 +0100 Subject: [PATCH 18/18] #2748: Port of PrimAITE Internal changes. --- CHANGELOG.md | 1 + src/primaite/game/agent/interface.py | 2 + src/primaite/game/agent/rewards.py | 61 +++++++++++++++---- .../game_layer/test_rewards.py | 21 +++---- 4 files changed, 63 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c63b114..7daf1f60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. - Added a `listen_on_ports` set in the `IOSoftware` class to enable software listening on ports in addition to the main port they're assigned. +- Added reward calculation details to AgentHistoryItem. ### Changed - File and folder observations can now be configured to always show the true health status, or require scanning like before. diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index f57dc191..14b97821 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -36,6 +36,8 @@ class AgentHistoryItem(BaseModel): reward: Optional[float] = None + reward_info: Dict[str, Any] = {} + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index c959ee5b..b913501d 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -47,7 +47,15 @@ class AbstractReward: @abstractmethod def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -67,7 +75,15 @@ class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the reward for the current state.""" + """Calculate the reward for the current state. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return 0.0 @classmethod @@ -109,8 +125,12 @@ class DatabaseFileIntegrity(AbstractReward): def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: """Calculate the reward for the current state. - :param state: The current state of the simulation. + :param state: Current simulation state :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: @@ -283,6 +303,12 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): component will keep track of that information. In that case, it doesn't matter whether the last successful request returned was able to connect to the database server, because there has been an unsuccessful request since. + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float """ if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: self._last_request_failed = last_action_response.response.status != "success" @@ -295,14 +321,11 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): # If the last request was actually sent, then check if the connection was established. db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: - _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_action_response.reward_info = {"reason": f"Can't calculate reward for {self.__class__.__name__}"} return 0.0 last_connection_successful = db_state["last_connection_successful"] - if last_connection_successful is False: - return -1.0 - elif last_connection_successful is True: - return 1.0 - return 0.0 + last_action_response.reward_info = {"last_connection_successful": last_connection_successful} + return 1.0 if last_connection_successful else -1.0 @classmethod def from_config(cls, config: Dict) -> AbstractReward: @@ -346,7 +369,15 @@ class SharedReward(AbstractReward): """Method that retrieves an agent's current reward given the agent's name.""" def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Simply access the other agent's reward and return it.""" + """Simply access the other agent's reward and return it. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ return self.callback(self.agent_name) @classmethod @@ -379,7 +410,15 @@ class ActionPenalty(AbstractReward): self.do_nothing_penalty = do_nothing_penalty def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float: - """Calculate the penalty to be applied.""" + """Calculate the penalty to be applied. + + :param state: Current simulation state + :type state: Dict + :param last_action_response: Current agent history state + :type last_action_response: AgentHistoryItem state + :return: Reward value + :rtype: float + """ if last_action_response.action == "DONOTHING": return self.do_nothing_penalty else: diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index 2bf551c8..e945f482 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -76,13 +76,16 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), + ahi = AgentHistoryItem( + timestep=0, + action="NODE_APPLICATION_EXECUTE", + parameters={}, + request=["execute"], + response=response, ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == 1.0 + assert ahi.reward_info == {"last_connection_successful": True} router.acl.remove_rule(position=2) @@ -92,13 +95,9 @@ def test_uc2_rewards(game_and_agent): ] ) state = game.get_sim_state() - reward_value = comp.calculate( - state, - last_action_response=AgentHistoryItem( - timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response - ), - ) + reward_value = comp.calculate(state, last_action_response=ahi) assert reward_value == -1.0 + assert ahi.reward_info == {"last_connection_successful": False} def test_shared_reward():