From 9d40e95982fa02935aa00a26d68abc3414803b60 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 2 Aug 2024 13:48:12 +0100 Subject: [PATCH 01/72] precommit json end of file fixes --- benchmark/results/v3/v3.2.0/session_metadata/1.json | 2 +- benchmark/results/v3/v3.2.0/session_metadata/2.json | 2 +- benchmark/results/v3/v3.2.0/session_metadata/3.json | 2 +- benchmark/results/v3/v3.2.0/session_metadata/4.json | 2 +- benchmark/results/v3/v3.2.0/session_metadata/5.json | 2 +- benchmark/results/v3/v3.2.0/v3.2.0_benchmark_metadata.json | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmark/results/v3/v3.2.0/session_metadata/1.json b/benchmark/results/v3/v3.2.0/session_metadata/1.json index 794f03e3..bfccfcdc 100644 --- a/benchmark/results/v3/v3.2.0/session_metadata/1.json +++ b/benchmark/results/v3/v3.2.0/session_metadata/1.json @@ -1006,4 +1006,4 @@ "999": 78.49999999999996, "1000": 84.69999999999993 } -} \ No newline at end of file +} diff --git a/benchmark/results/v3/v3.2.0/session_metadata/2.json b/benchmark/results/v3/v3.2.0/session_metadata/2.json index e48c34b9..c35b5ae6 100644 --- a/benchmark/results/v3/v3.2.0/session_metadata/2.json +++ b/benchmark/results/v3/v3.2.0/session_metadata/2.json @@ -1006,4 +1006,4 @@ "999": 97.59999999999975, "1000": 103.34999999999978 } -} \ No newline at end of file +} diff --git a/benchmark/results/v3/v3.2.0/session_metadata/3.json b/benchmark/results/v3/v3.2.0/session_metadata/3.json index 4e2d845c..342e0f7d 100644 --- a/benchmark/results/v3/v3.2.0/session_metadata/3.json +++ b/benchmark/results/v3/v3.2.0/session_metadata/3.json @@ -1006,4 +1006,4 @@ "999": 101.14999999999978, "1000": 80.94999999999976 } -} \ No newline at end of file +} diff --git a/benchmark/results/v3/v3.2.0/session_metadata/4.json b/benchmark/results/v3/v3.2.0/session_metadata/4.json index 6e03a18f..6aaf9ab8 100644 --- a/benchmark/results/v3/v3.2.0/session_metadata/4.json +++ b/benchmark/results/v3/v3.2.0/session_metadata/4.json @@ -1006,4 +1006,4 @@ "999": 118.0500000000001, "1000": 77.95000000000005 } -} \ No newline at end of file +} diff --git a/benchmark/results/v3/v3.2.0/session_metadata/5.json b/benchmark/results/v3/v3.2.0/session_metadata/5.json index ca7ad1e9..05cf76ed 100644 --- a/benchmark/results/v3/v3.2.0/session_metadata/5.json +++ b/benchmark/results/v3/v3.2.0/session_metadata/5.json @@ -1006,4 +1006,4 @@ "999": 55.849999999999916, "1000": 96.95000000000007 } -} \ No newline at end of file +} diff --git a/benchmark/results/v3/v3.2.0/v3.2.0_benchmark_metadata.json b/benchmark/results/v3/v3.2.0/v3.2.0_benchmark_metadata.json index 830e980e..111ae25f 100644 --- a/benchmark/results/v3/v3.2.0/v3.2.0_benchmark_metadata.json +++ b/benchmark/results/v3/v3.2.0/v3.2.0_benchmark_metadata.json @@ -7442,4 +7442,4 @@ } } } -} \ No newline at end of file +} From 0ff88e36726ff7e652047ec6e3a78dc46576c6d3 Mon Sep 17 00:00:00 2001 From: Archer Bowen Date: Mon, 2 Sep 2024 11:50:49 +0100 Subject: [PATCH 02/72] #2840 Initial Implementation completed and tested. --- src/primaite/game/agent/actions.py | 23 +++++++++ .../system/services/terminal/terminal.py | 22 +++++++- tests/conftest.py | 1 + .../actions/test_terminal_actions.py | 51 +++++++++++++++++++ 4 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 2e6189c0..3dc1f514 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1266,6 +1266,28 @@ class NodeSendRemoteCommandAction(AbstractAction): ] +class NodeSendLocalCommandAction(AbstractAction): + """Action which sends a terminal command using a local terminal session.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: int, username: str, password: str, command: RequestFormat) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "Terminal", + "send_local_command", + username, + password, + {"command": command}, + ] + + class TerminalC2ServerAction(AbstractAction): """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" @@ -1372,6 +1394,7 @@ class ActionManager: "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, "NODE_SEND_REMOTE_COMMAND": NodeSendRemoteCommandAction, + "NODE_SEND_LOCAL_COMMAND": NodeSendLocalCommandAction, } """Dictionary which maps action type strings to the corresponding action class.""" diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index e98e8555..9b88bbe8 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -208,7 +208,6 @@ class Terminal(Service): status="success", data={}, ) - else: return RequestResponse( status="failure", data={}, @@ -219,6 +218,27 @@ class Terminal(Service): request_type=RequestType(func=remote_execute_request), ) + def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + """Executes a command using a local terminal session.""" + command: str = request[2]["command"] + local_connection = self._process_local_login(username=request[0], password=request[1]) + if local_connection: + outcome = local_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={"reason": outcome}, + ) + return RequestResponse( + status="success", + data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"}, + ) + + rm.add_request( + "send_local_command", + request_type=RequestType(func=local_execute_request), + ) + return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..8717abfa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -467,6 +467,7 @@ def game_and_agent(): {"type": "SSH_TO_REMOTE"}, {"type": "SESSIONS_REMOTE_LOGOFF"}, {"type": "NODE_SEND_REMOTE_COMMAND"}, + {"type": "NODE_SEND_LOCAL_COMMAND"}, ] action_space = ActionManager( diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index d011c1e8..d2ea7202 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -164,3 +164,54 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam assert server_1.file_system.get_folder("folder123") is None assert server_1.file_system.get_file("folder123", "doggo.pdf") is None + + +def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + # create a new user account on server_1 that will be logged into remotely + client_1_usm: UserManager = client_1.software_manager.software["UserManager"] + client_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "NODE_SEND_LOCAL_COMMAND", + { + "node_id": 0, + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_folder("folder123") + assert client_1.file_system.get_file("folder123", "doggo.pdf") + + # Change password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + { + "node_id": 0, # server_1 + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + action = ( + "NODE_SEND_LOCAL_COMMAND", + { + "node_id": 0, + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "cat.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_file("folder123", "cat.pdf") is None From a7f00c668dc75932f6cb72de9f8709ce672b58f2 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 2 Sep 2024 15:15:45 +0100 Subject: [PATCH 03/72] #2782: initial impl of files in nodes --- src/primaite/game/game.py | 6 + .../configs/nodes_with_initial_files.yaml | 256 ++++++++++++++++++ .../test_node_file_system_config.py | 47 ++++ 3 files changed, 309 insertions(+) create mode 100644 tests/assets/configs/nodes_with_initial_files.yaml create mode 100644 tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..befa4032 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -329,6 +329,12 @@ class PrimaiteGame: _LOGGER.error(msg) raise ValueError(msg) + # handle node file system + if node_cfg.get("file_system") is not None and len(node_cfg.get("file_system")) > 0: + for folder in node_cfg.get("file_system"): + for file in node_cfg["file_system"][folder]: + new_node.file_system.create_file(folder_name=folder, file_name=file) + if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa for user_cfg in node_cfg["users"]: diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml new file mode 100644 index 00000000..3213098b --- /dev/null +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -0,0 +1,256 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: false + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: "NONE" + label: ICS + options: {} + + action_space: + action_list: + - type: DONOTHING + + action_map: + 0: + action: DONOTHING + options: {} + options: + nodes: + - node_name: switch + - node_name: client_1 + - node_name: client_2 + - node_name: client_3 + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + + - type: WEB_SERVER_404_PENALTY + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: RansomwareScript + - type: WebBrowser + options: + target_url: http://arcd.com/users/ + - type: DatabaseClient + options: + db_server_ip: 192.168.1.10 + server_password: arcd + - type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - type: DoSBot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - type: DNSClient + options: + dns_server: 192.168.1.10 + - type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.10 + - type: DatabaseService + options: + backup_server_ip: 192.168.1.10 + - type: WebServer + - type: FTPServer + options: + server_password: arcd + - type: NTPClient + options: + ntp_server_ip: 192.168.1.10 + - type: NTPServer + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + file_system: + downloads: + - "test.txt" + - "suh_con.dn" + root: + - "passwords.txt" + # pre installed services and applications + - hostname: client_3 + type: computer + ip_address: 192.168.10.23 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + start_up_duration: 0 + shut_down_duration: 0 + operating_state: "OFF" + # pre installed services and applications + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py new file mode 100644 index 00000000..05ef7275 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -0,0 +1,47 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from pathlib import Path +from typing import Union + +import yaml + +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_node_file_system_from_config(): + """Test that the appropriate files are instantiated in nodes when loaded from config.""" + game = load_config(BASIC_CONFIG) + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.software_manager.software.get("DatabaseService") # database service should be installed + assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist + + assert client_1.software_manager.software.get("WebServer") # web server should be installed + assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist + + client_2 = game.simulation.network.get_node_by_hostname("client_2") + + # database service should not be installed + assert client_2.software_manager.software.get("DatabaseService") is None + # database files should not exist + assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None + + # web server should not be installed + assert client_2.software_manager.software.get("WebServer") is None + # web files should not exist + assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None + + # TODO file sizes and file types + # TODO assert that files and folders created: + # TODO create empty folders From 5cacbf03373bccd634c8086783222bb45648871a Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 2 Sep 2024 16:54:13 +0100 Subject: [PATCH 04/72] #2845: Changes to write observation space data to log file. --- src/primaite/session/environment.py | 25 +++++++++++++++++++++++++ src/primaite/session/io.py | 2 ++ 2 files changed, 27 insertions(+) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index c66663e3..23b86546 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -112,6 +112,9 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.update_agents(state) next_obs = self._get_obs() # this doesn't update observation, just gets the current observation + if self.io.settings.obs_space_data: + # Write unflattened observation space to log file. + self._write_obs_space_data(self.agent.observation_manager.current_observation) reward = self.agent.reward_function.current_reward _LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}") terminated = False @@ -139,6 +142,25 @@ class PrimaiteGymEnv(gymnasium.Env): with open(path, "w") as file: json.dump(data, file) + def _write_obs_space_data(self, obs_space: ObsType) -> None: + """Write the unflattened observation space data to a JSON file. + + :param obs: Observation of the environment (dict) + :type obs: ObsType + """ + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.episode_counter, + "step": self.game.step_counter, + "obs_space_data": obs_space, + } + with open(path, "w") as file: + json.dump(data, file) + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" _LOGGER.info( @@ -159,6 +181,9 @@ class PrimaiteGymEnv(gymnasium.Env): state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() + if self.io.settings.obs_space_data: + # Write unflattened observation space to log file. + self._write_obs_space_data(self.agent.observation_manager.current_observation) info = {} return next_obs, info diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 78d7cb3c..3627e9e9 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -45,6 +45,8 @@ class PrimaiteIO: """The level of sys logs that should be included in the logfiles/logged into terminal.""" agent_log_level: LogLevel = LogLevel.INFO """The level of agent logs that should be included in the logfiles/logged into terminal.""" + obs_space_data: bool = False + """Whether to save observation space data to a log file.""" def __init__(self, settings: Optional[Settings] = None) -> None: """ From fd3d3812f6d8d03b1d44261b06ff39fecd0e2209 Mon Sep 17 00:00:00 2001 From: Archer Bowen Date: Mon, 2 Sep 2024 16:55:43 +0100 Subject: [PATCH 05/72] #2840 Documentation and minor bug fixes found in terminal and session manager. --- CHANGELOG.md | 2 + .../system/services/terminal.rst | 106 ++++++- .../notebooks/Terminal-Processing.ipynb | 274 +++++++++++++++++- .../simulator/system/core/session_manager.py | 2 +- .../system/services/terminal/terminal.py | 8 +- .../actions/test_terminal_actions.py | 1 + 6 files changed, 385 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..2a855512 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) + ### Added - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index f982145d..6a1b0204 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -30,6 +30,7 @@ Usage - Terminal Clients connect, execute commands and disconnect from remote nodes. - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. + - Enables Agents to send commands both remotely and locally. Implementation """""""""""""" @@ -39,9 +40,110 @@ Implementation - Extends Service class. - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. +Command Format +^^^^^^^^^^^^^^ + +``Terminals`` implement their commands through leveraging the pre-existing :doc:`../../request_system`. + +Due to this ``Terminals`` will only accept commands passed within the ``RequestFormat``. + +:py:class:`primaite.game.interface.RequestFormat` + +For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows: + +.. code-block:: yaml + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + +**This command creates file called ``cat.png`` within the ``downloads`` folder.** + +This is then loaded from ``yaml`` into a dictionary containing the terminal command: + +.. code-block:: python + + {"command":["file_system", "create", "file", "downloads", "cat.png", "False"]} + +Which is then parsed to the ``Terminals`` Request Manager to be executed. + +Game Layer Usage (Agents) +======================== + +The below code examples demonstrate how to use terminal related actions in yaml files. + +yaml +"""" + +``NODE_SEND_LOCAL_COMMAND`` +""""""""""""""""""""""""""" + +Agents can execute local commands without needing to perform a separate remote login action (``SSH_TO_REMOTE``). + +.. code-block:: yaml + + ... + ... + action: NODE_SEND_LOCAL_COMMAND + options: + node_id: 0 + username: admin + password: admin + command: # Example command - Creates a file called 'cat.png' in the downloads folder. + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + + +``SSH_TO_REMOTE`` +""""""""""""""""" + +Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts. + +.. code-block:: yaml + + ... + ... + action: SSH_TO_REMOTE + options: + node_id: 0 + username: admin + password: admin + remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh) + + +``NODE_SEND_REMOTE_COMMAND`` +"""""""""""""""""""""""""""" + +After remotely login into another host, a agent can use the ``NODE_SEND_REMOTE_COMMAND`` to execute commands across the network remotely. + +.. code-block:: yaml + + ... + ... + action: NODE_SEND_REMOTE_COMMAND + options: + node_id: 0 + remote_ip: 192.168.0.10 + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + + + +Simulation Layer Usage +====================== -Usage -===== The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node. diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index fdf405a7..19ce567e 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -9,6 +9,13 @@ "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simulation Layer Implementation." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -198,6 +205,271 @@ "source": [ "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Game Layer Implementation\n", + "\n", + "This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n", + "\n", + "The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n", + "\n", + "\n", + "| Game Layer Action | Simulation Layer |\n", + "|-----------------------------------|--------------------------|\n", + "| ``NODE_SEND_LOCAL_COMMAND`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", + "| ``SSH_TO_REMOTE`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", + "| ``NODE_SEND_REMOTE_COMMAND`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Game Layer Setup\n", + "\n", + "Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n", + "\n", + "If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "custom_terminal_agent = \"\"\"\n", + " - ref: CustomC2Agent\n", + " team: RED\n", + " type: ProxyAgent\n", + " observation_space: null\n", + " action_space:\n", + " action_list:\n", + " - type: DONOTHING\n", + " - type: NODE_SEND_LOCAL_COMMAND\n", + " - type: SSH_TO_REMOTE\n", + " - type: NODE_SEND_REMOTE_COMMAND\n", + " options:\n", + " nodes:\n", + " - node_name: client_1\n", + " max_folders_per_node: 1\n", + " max_files_per_folder: 1\n", + " max_services_per_node: 2\n", + " max_nics_per_node: 8\n", + " max_acl_rules: 10\n", + " ip_list:\n", + " - 192.168.1.21\n", + " - 192.168.1.14\n", + " wildcard_list:\n", + " - 0.0.0.1\n", + " action_map:\n", + " 0:\n", + " action: DONOTHING\n", + " options: {}\n", + " 1:\n", + " action: NODE_SEND_LOCAL_COMMAND\n", + " options:\n", + " node_id: 0\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + " 2:\n", + " action: SSH_TO_REMOTE\n", + " options:\n", + " node_id: 0\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22\n", + " 3:\n", + " action: NODE_SEND_REMOTE_COMMAND\n", + " options:\n", + " node_id: 0\n", + " remote_ip: 192.168.10.22\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + " reward_function:\n", + " reward_components:\n", + " - type: DUMMY\n", + "\"\"\"\n", + "custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path()) as f:\n", + " cfg = yaml.safe_load(f)\n", + " # removing all agents & adding the custom agent.\n", + " cfg['agents'] = {}\n", + " cfg['agents'] = custom_terminal_agent_yaml\n", + " \n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", + "client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``NODE_SEND_LOCAL_COMMAND`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: NODE_SEND_LOCAL_COMMAND\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: NODE_SEND_LOCAL_COMMAND\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(1)\n", + "client_1.file_system.show(full=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``SSH_TO_REMOTE`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: SSH_TO_REMOTE\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 2:\n", + " action: SSH_TO_REMOTE\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22 # client_2's ip address.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(2)\n", + "client_2.session_manager.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``NODE_SEND_REMOTE_COMMAND``\n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: NODE_SEND_REMOTE_COMMAND\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: NODE_SEND_REMOTE_COMMAND\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " remote_ip: 192.168.10.22\n", + " commands:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(3)\n", + "client_2.file_system.show(full=True)" + ] } ], "metadata": { @@ -216,7 +488,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/src/primaite/simulator/system/core/session_manager.py b/src/primaite/simulator/system/core/session_manager.py index b7e2c021..677ff477 100644 --- a/src/primaite/simulator/system/core/session_manager.py +++ b/src/primaite/simulator/system/core/session_manager.py @@ -413,5 +413,5 @@ class SessionManager: table.align = "l" table.title = f"{self.sys_log.hostname} Session Manager" for session in self.sessions_by_key.values(): - table.add_row([session.dst_ip_address, session.dst_port.value, session.protocol.name]) + table.add_row([session.with_ip_address, session.dst_port.value, session.protocol.name]) print(table) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 9b88bbe8..dc7da205 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -208,10 +208,10 @@ class Terminal(Service): status="success", data={}, ) - return RequestResponse( - status="failure", - data={}, - ) + return RequestResponse( + status="failure", + data={}, + ) rm.add_request( "send_remote_command", diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index d2ea7202..c4247d6e 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -215,3 +215,4 @@ def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]) game.step() assert client_1.file_system.get_file("folder123", "cat.pdf") is None + client_1.session_manager.show() From 8e6b9f39707e1236f3d5f9e8f85d962b33f0e1d5 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 3 Sep 2024 11:53:23 +0100 Subject: [PATCH 06/72] #2782: added ability to create empty folders + create files with size and types + tests --- src/primaite/game/game.py | 21 ++++++++++++++--- .../configs/nodes_with_initial_files.yaml | 9 +++++--- .../test_node_file_system_config.py | 23 ++++++++++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index befa4032..d11f6a19 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -17,6 +17,7 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT +from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.airspace import AirSpaceFrequency from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -331,9 +332,23 @@ class PrimaiteGame: # handle node file system if node_cfg.get("file_system") is not None and len(node_cfg.get("file_system")) > 0: - for folder in node_cfg.get("file_system"): - for file in node_cfg["file_system"][folder]: - new_node.file_system.create_file(folder_name=folder, file_name=file) + for folder_idx, folder_obj in enumerate(node_cfg.get("file_system")): + # if the folder is not a Dict, create an empty folder + if not isinstance(folder_obj, Dict): + new_node.file_system.create_folder(folder_name=folder_obj) + else: + folder_name = next(iter(folder_obj)) + for file_idx, file_obj in enumerate(node_cfg["file_system"][folder_idx][folder_name]): + if not isinstance(file_obj, Dict): + new_node.file_system.create_file(folder_name=folder_name, file_name=file_obj) + else: + file_name = next(iter(file_obj)) + new_node.file_system.create_file( + folder_name=folder_name, + file_name=file_name, + size=file_obj[file_name].get("size", 0), + file_type=FileType[file_obj[file_name].get("type", "UNKNOWN").upper()], + ) if "users" in node_cfg and new_node.software_manager.software.get("UserManager"): user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml index 3213098b..fad6cffd 100644 --- a/tests/assets/configs/nodes_with_initial_files.yaml +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -226,11 +226,14 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 file_system: - downloads: + - empty_folder + - downloads: - "test.txt" - "suh_con.dn" - root: - - "passwords.txt" + - root: + - passwords: + size: 69 + type: TXT # pre installed services and applications - hostname: client_3 type: computer diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py index 05ef7275..49e90b54 100644 --- a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -5,6 +5,7 @@ from typing import Union import yaml from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_type import FileType from tests import TEST_ASSETS_ROOT BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml" @@ -42,6 +43,22 @@ def test_node_file_system_from_config(): # web files should not exist assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None - # TODO file sizes and file types - # TODO assert that files and folders created: - # TODO create empty folders + empty_folder = client_2.file_system.get_folder(folder_name="empty_folder") + assert empty_folder + assert len(empty_folder.files) == 0 # should have no files + + password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt") + assert password_file # should exist + assert password_file.file_type is FileType.TXT + assert password_file.size is 69 + + downloads_folder = client_2.file_system.get_folder(folder_name="downloads") + assert downloads_folder # downloads folder should exist + + test_txt = downloads_folder.get_file(file_name="test.txt") + assert test_txt # test.txt should exist + assert test_txt.file_type is FileType.TXT + + unknown_file_type = downloads_folder.get_file(file_name="suh_con.dn") + assert unknown_file_type # unknown_file_type should exist + assert unknown_file_type.file_type is FileType.UNKNOWN From 26a56bf3608d2f0c7930d8e0b6e5faa0830e092f Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 3 Sep 2024 12:37:39 +0100 Subject: [PATCH 07/72] #2782: documentation + adding example to data_manipulation.yaml --- .../nodes/common/common_node_attributes.rst | 33 +++++++++++++++++++ .../_package_data/data_manipulation.yaml | 9 ++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 7cf11eb4..056422ca 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -54,6 +54,39 @@ Optional. Default value is ``3``. The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. +``file_system`` +--------------- + +Optional. + +The file system of the node. This configuration allows nodes to be initialised with files and/or folders. + +The file system takes a list of folders and files. + +Example: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + file_system: + - empty_folder # example of an empty folder + - downloads: + - "test_1.txt" # files in the downloads folder + - "test_2.txt" + - root: + - passwords: # example of file with size and type + size: 69 # size in bytes + type: TXT # See FileType for list of available file types + +List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType` + ``users`` --------- diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 97442903..2d03609a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -843,7 +843,14 @@ simulation: dns_server: 192.168.1.10 services: - type: FTPServer - + file_system: + - root: + - backup_script.sh: # example file in backup server + size: 400 + type: SH + - downloads: + - "ChromeSetup.exe" # another example file + - "New Folder" # example of an empty folder - hostname: security_suite type: server ip_address: 192.168.1.110 From 8e57e707b3e1d5eec3b53d6deeb90d7b9289338b Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 3 Sep 2024 14:38:19 +0100 Subject: [PATCH 08/72] #2845: Changed to store obs data within AgentHistoryItem --- src/primaite/game/agent/interface.py | 18 ++++++++++++++++-- src/primaite/game/game.py | 1 + src/primaite/session/environment.py | 25 ------------------------- src/primaite/session/io.py | 2 -- 4 files changed, 17 insertions(+), 29 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 14b97821..aac6c05a 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -38,6 +38,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + obs_space_data: Optional[ObsType] = None + """The observation space data for this step.""" + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -169,12 +172,23 @@ class AbstractAgent(ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + obs_space_data: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + obs_space_data=obs_space_data, ) ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..ed3c84d3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -186,6 +186,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + obs_space_data=obs, ) def pre_timestep(self) -> None: diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 23b86546..c66663e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -112,9 +112,6 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.update_agents(state) next_obs = self._get_obs() # this doesn't update observation, just gets the current observation - if self.io.settings.obs_space_data: - # Write unflattened observation space to log file. - self._write_obs_space_data(self.agent.observation_manager.current_observation) reward = self.agent.reward_function.current_reward _LOGGER.debug(f"step: {self.game.step_counter}, Blue reward: {reward}") terminated = False @@ -142,25 +139,6 @@ class PrimaiteGymEnv(gymnasium.Env): with open(path, "w") as file: json.dump(data, file) - def _write_obs_space_data(self, obs_space: ObsType) -> None: - """Write the unflattened observation space data to a JSON file. - - :param obs: Observation of the environment (dict) - :type obs: ObsType - """ - output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "obs_space_data" - - output_dir.mkdir(parents=True, exist_ok=True) - path = output_dir / f"step_{self.game.step_counter}.json" - - data = { - "episode": self.episode_counter, - "step": self.game.step_counter, - "obs_space_data": obs_space, - } - with open(path, "w") as file: - json.dump(data, file) - def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" _LOGGER.info( @@ -181,9 +159,6 @@ class PrimaiteGymEnv(gymnasium.Env): state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() - if self.io.settings.obs_space_data: - # Write unflattened observation space to log file. - self._write_obs_space_data(self.agent.observation_manager.current_observation) info = {} return next_obs, info diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 3627e9e9..78d7cb3c 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -45,8 +45,6 @@ class PrimaiteIO: """The level of sys logs that should be included in the logfiles/logged into terminal.""" agent_log_level: LogLevel = LogLevel.INFO """The level of agent logs that should be included in the logfiles/logged into terminal.""" - obs_space_data: bool = False - """Whether to save observation space data to a log file.""" def __init__(self, settings: Optional[Settings] = None) -> None: """ From 61add769c46b6d8a4f255e301e9d19f5d6a7ddfb Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 3 Sep 2024 17:16:48 +0100 Subject: [PATCH 09/72] #2845: Add test for obs_data_space capture. --- .../observations/test_obs_data_capture.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/integration_tests/game_layer/observations/test_obs_data_capture.py diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py new file mode 100644 index 00000000..205341d9 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -0,0 +1,25 @@ +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.io import PrimaiteIO +import json +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + +def test_obs_data_in_log_file(): + """Create a log file of AgentHistoryItems and check observation data is + included. Assumes that data_manipulation.yaml has an agent labelled + 'defender' with a non-null observation space. + The log file will be in: + primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions + """ + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.reset() + for _ in range(10): + env.step(0) + env.reset() + io = PrimaiteIO() + path = io.generate_agent_actions_save_path(episode=1) + with open(path, 'r') as f: + j = json.load(f) + + assert type(j['0']['defender']['obs_space_data']) == dict From 1822e85eec61710c69db4deaeaeaba2d49053a83 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 3 Sep 2024 17:24:21 +0100 Subject: [PATCH 10/72] #2845: Pre-commit fixes --- .../game_layer/observations/test_obs_data_capture.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py index 205341d9..810b2ad7 100644 --- a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -1,12 +1,15 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json + from primaite.session.environment import PrimaiteGymEnv from primaite.session.io import PrimaiteIO -import json from tests import TEST_ASSETS_ROOT DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + def test_obs_data_in_log_file(): - """Create a log file of AgentHistoryItems and check observation data is + """Create a log file of AgentHistoryItems and check observation data is included. Assumes that data_manipulation.yaml has an agent labelled 'defender' with a non-null observation space. The log file will be in: @@ -19,7 +22,7 @@ def test_obs_data_in_log_file(): env.reset() io = PrimaiteIO() path = io.generate_agent_actions_save_path(episode=1) - with open(path, 'r') as f: + with open(path, "r") as f: j = json.load(f) - assert type(j['0']['defender']['obs_space_data']) == dict + assert type(j["0"]["defender"]["obs_space_data"]) == dict From f4b1d9a91c5566ca6ba49056479d0e8c21f38abe Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 3 Sep 2024 17:26:01 +0100 Subject: [PATCH 11/72] #2845: Update CHANGELOG. --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d08974c..e2989247 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Log observation space data by episode and step. + +## [3.3.0] - 2024-09-04 +### Added - Random Number Generator Seeding by specifying a random number seed in the config file. - Implemented Terminal service class, providing a generic terminal simulation. - Added `User`, `UserManager` and `UserSessionManager` to enable the creation of user accounts and login on Nodes. From 1374a23e14fb9fea35c346747eb8d9edd303c2ca Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 4 Sep 2024 10:17:33 +0100 Subject: [PATCH 12/72] #2782: fix spacing in data_manipulation yaml + documentation --- .../simulation/nodes/common/common_node_attributes.rst | 4 ++-- src/primaite/config/_package_data/data_manipulation.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 056422ca..c94344fd 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -82,8 +82,8 @@ Example: - "test_2.txt" - root: - passwords: # example of file with size and type - size: 69 # size in bytes - type: TXT # See FileType for list of available file types + size: 69 # size in bytes + type: TXT # See FileType for list of available file types List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType` diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 2d03609a..b36ec707 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -846,8 +846,8 @@ simulation: file_system: - root: - backup_script.sh: # example file in backup server - size: 400 - type: SH + size: 400 + type: SH - downloads: - "ChromeSetup.exe" # another example file - "New Folder" # example of an empty folder From 5608ad5ed5799d0dfb02d5767a5fde0f343ff0e7 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Wed, 4 Sep 2024 14:25:08 +0100 Subject: [PATCH 13/72] #2845: Change 'obs_space_data' to 'observation'. --- src/primaite/game/agent/interface.py | 6 +++--- src/primaite/game/game.py | 2 +- .../game_layer/observations/test_obs_data_capture.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index aac6c05a..d5165a71 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -38,7 +38,7 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} - obs_space_data: Optional[ObsType] = None + observation: Optional[ObsType] = None """The observation space data for this step.""" @@ -178,7 +178,7 @@ class AbstractAgent(ABC): parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse, - obs_space_data: ObsType, + observation: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( @@ -188,7 +188,7 @@ class AbstractAgent(ABC): parameters=parameters, request=request, response=response, - obs_space_data=obs_space_data, + observation=observation, ) ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index ed3c84d3..4f21120d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -186,7 +186,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, - obs_space_data=obs, + observation=obs, ) def pre_timestep(self) -> None: diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py index 810b2ad7..e8bdea22 100644 --- a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -25,4 +25,4 @@ def test_obs_data_in_log_file(): with open(path, "r") as f: j = json.load(f) - assert type(j["0"]["defender"]["obs_space_data"]) == dict + assert type(j["0"]["defender"]["observation"]) == dict From 2391c485698a645a035333208252c39209c1a9da Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 5 Sep 2024 10:18:35 +0100 Subject: [PATCH 14/72] #2782: apply suggestions --- src/primaite/config/_package_data/data_manipulation.yaml | 8 -------- src/primaite/game/game.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index b36ec707..2a069971 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -843,14 +843,6 @@ simulation: dns_server: 192.168.1.10 services: - type: FTPServer - file_system: - - root: - - backup_script.sh: # example file in backup server - size: 400 - type: SH - - downloads: - - "ChromeSetup.exe" # another example file - - "New Folder" # example of an empty folder - hostname: security_suite type: server ip_address: 192.168.1.110 diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index d11f6a19..8e4d4513 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -331,7 +331,7 @@ class PrimaiteGame: raise ValueError(msg) # handle node file system - if node_cfg.get("file_system") is not None and len(node_cfg.get("file_system")) > 0: + if node_cfg.get("file_system"): for folder_idx, folder_obj in enumerate(node_cfg.get("file_system")): # if the folder is not a Dict, create an empty folder if not isinstance(folder_obj, Dict): From e809d89c30d3ba438d4edabfe88ea9c1ba9f226d Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 5 Sep 2024 13:47:59 +0100 Subject: [PATCH 15/72] #2842 and #2843: implement add user and disable user actions + tests --- src/primaite/game/agent/actions.py | 34 +++++++ .../simulator/network/hardware/base.py | 16 +++- tests/conftest.py | 2 + .../actions/test_user_account_actions.py | 93 +++++++++++++++++++ 4 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 tests/integration_tests/game_layer/actions/test_user_account_actions.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 2e6189c0..a299788e 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -1116,6 +1116,38 @@ class ConfigureC2BeaconAction(AbstractAction): return ["network", "node", node_name, "application", "C2Beacon", "configure", config.__dict__] +class NodeAccountsAddUserAction(AbstractAction): + """Action which changes adds a User.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: str, username: str, password: str, is_admin: bool) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return ["network", "node", node_name, "service", "UserManager", "add_user", username, password, is_admin] + + +class NodeAccountsDisableUserAction(AbstractAction): + """Action which disables a user.""" + + def __init__(self, manager: "ActionManager", **kwargs) -> None: + super().__init__(manager=manager) + + def form_request(self, node_id: str, username: str) -> RequestFormat: + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + node_name = self.manager.get_node_name_by_idx(node_id) + return [ + "network", + "node", + node_name, + "service", + "UserManager", + "disable_user", + username, + ] + + class NodeAccountsChangePasswordAction(AbstractAction): """Action which changes the password for a user.""" @@ -1368,6 +1400,8 @@ class ActionManager: "C2_SERVER_RANSOMWARE_CONFIGURE": RansomwareConfigureC2ServerAction, "C2_SERVER_TERMINAL_COMMAND": TerminalC2ServerAction, "C2_SERVER_DATA_EXFILTRATE": ExfiltrationC2ServerAction, + "NODE_ACCOUNTS_ADD_USER": NodeAccountsAddUserAction, + "NODE_ACCOUNTS_DISABLE_USER": NodeAccountsDisableUserAction, "NODE_ACCOUNTS_CHANGE_PASSWORD": NodeAccountsChangePasswordAction, "SSH_TO_REMOTE": NodeSessionsRemoteLoginAction, "SESSIONS_REMOTE_LOGOFF": NodeSessionsRemoteLogoutAction, diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index ef2d47c3..f49d0a17 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -857,7 +857,21 @@ class UserManager(Service): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas + rm.add_request( + "add_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.add_user(username=request[0], password=request[1], is_admin=request[2]) + ) + ), + ) + rm.add_request( + "disable_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0])) + ), + ) rm.add_request( "change_password", RequestType( diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..50877378 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -463,6 +463,8 @@ def game_and_agent(): {"type": "C2_SERVER_RANSOMWARE_CONFIGURE"}, {"type": "C2_SERVER_TERMINAL_COMMAND"}, {"type": "C2_SERVER_DATA_EXFILTRATE"}, + {"type": "NODE_ACCOUNTS_ADD_USER"}, + {"type": "NODE_ACCOUNTS_DISABLE_USER"}, {"type": "NODE_ACCOUNTS_CHANGE_PASSWORD"}, {"type": "SSH_TO_REMOTE"}, {"type": "SESSIONS_REMOTE_LOGOFF"}, diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py new file mode 100644 index 00000000..fd720315 --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -0,0 +1,93 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.hardware.nodes.host.computer import Computer + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_user_account_add_user_action(game_and_agent_fixture): + """Tests the add user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert len(client_1.user_manager.users) == 1 # admin is created by default + assert len(client_1.user_manager.admins) == 1 + + # add admin account + action = ( + "NODE_ACCOUNTS_ADD_USER", + {"node_id": 0, "username": "soccon_diiz", "password": "nuts", "is_admin": True}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + # add non admin account + action = ( + "NODE_ACCOUNTS_ADD_USER", + {"node_id": 0, "username": "mike_rotch", "password": "password", "is_admin": False}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 3 # new user added + assert len(client_1.user_manager.admins) == 2 + + +def test_user_account_disable_user_action(game_and_agent_fixture): + """Tests the disable user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="icles", is_admin=True) + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + test_user = client_1.user_manager.users.get("test") + assert test_user + assert test_user.disabled is not True + + # disable test account + action = ( + "NODE_ACCOUNTS_DISABLE_USER", + { + "node_id": 0, + "username": "test", + }, + ) + agent.store_action(action) + game.step() + assert test_user.disabled + + +def test_user_account_change_password_action(game_and_agent_fixture): + """Tests the change password user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="icles", is_admin=True) + + test_user = client_1.user_manager.users.get("test") + assert test_user.password == "icles" + + # change account password + action = ( + "NODE_ACCOUNTS_CHANGE_PASSWORD", + {"node_id": 0, "username": "test", "current_password": "icles", "new_password": "2Hard_2_Hack"}, + ) + agent.store_action(action) + game.step() + + assert test_user.password == "2Hard_2_Hack" From a998b8e22b2fd605583fcc6f455894563c5c4ad5 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 5 Sep 2024 16:47:17 +0100 Subject: [PATCH 16/72] #2345: remove try catch + todo - figure out why db connection cannot be made --- .../system/services/web_server/web_server.py | 54 ++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 4fc64e1f..f9f561df 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -108,37 +108,43 @@ class WebServer(Service): :type: payload: HttpRequestPacket """ response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload) - try: - parsed_url = urlparse(payload.request_url) - path = parsed_url.path.strip("/") - if len(path) < 1: + parsed_url = urlparse(payload.request_url) + path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else "" + + if len(path) < 1: + # query succeeded + response.status_code = HttpStatusCode.OK + + if path.startswith("users"): + # get data from DatabaseServer + # get all users + if self._establish_db_connection(): + # unable to create a db connection + response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + + if self.db_connection.query("SELECT"): # query succeeded + self.set_health_state(SoftwareHealthState.GOOD) response.status_code = HttpStatusCode.OK + else: + self.set_health_state(SoftwareHealthState.COMPROMISED) + return response - if path.startswith("users"): - # get data from DatabaseServer - # get all users - if not self.db_connection: - self._establish_db_connection() - - if self.db_connection.query("SELECT"): - # query succeeded - self.set_health_state(SoftwareHealthState.GOOD) - response.status_code = HttpStatusCode.OK - else: - self.set_health_state(SoftwareHealthState.COMPROMISED) - - return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) - # something went wrong on the server - response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR - return response - - def _establish_db_connection(self) -> None: + def _establish_db_connection(self) -> bool: """Establish a connection to db.""" + # if active db connection, return true + if self.db_connection: + return True + + # otherwise, try to create db connection db_client = self.software_manager.software.get("DatabaseClient") + + if db_client is None: + return False # database client not installed + self.db_connection: DatabaseClientConnection = db_client.get_new_connection() + return self.db_connection is not None def send( self, From 974aee90b37afd3be0cfddb159cddd63892d2bb4 Mon Sep 17 00:00:00 2001 From: "Archer.Bowen" Date: Fri, 6 Sep 2024 14:09:30 +0100 Subject: [PATCH 17/72] #2842 Added additional tests to confirm terminal functionality --- .../actions/test_user_account_actions.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py index fd720315..bb36ce73 100644 --- a/tests/integration_tests/game_layer/actions/test_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -2,6 +2,8 @@ import pytest from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.transport_layer import Port @pytest.fixture @@ -91,3 +93,84 @@ def test_user_account_change_password_action(game_and_agent_fixture): game.step() assert test_user.password == "2Hard_2_Hack" + + +def test_user_account_create_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to create new users.""" + game, agent = game_and_agent_fixture + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Create a new user account via terminal. + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "UserManager", "add_user", "new_user", "new_pass", True], + }, + ) + agent.store_action(action) + game.step() + new_user = server_1.user_manager.users.get("new_user") + assert new_user + assert new_user.password == "new_pass" + assert new_user.disabled is not True + + +def test_user_account_disable_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to disable users.""" + game, agent = game_and_agent_fixture + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["UserManager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "SSH_TO_REMOTE", + { + "node_id": 0, + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Disable a user via terminal + action = ( + "NODE_SEND_REMOTE_COMMAND", + { + "node_id": 0, + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "UserManager", "disable_user", "user123"], + }, + ) + agent.store_action(action) + game.step() + + new_user = server_1.user_manager.users.get("user123") + assert new_user + assert new_user.disabled is True From 5ab42ead273934a3132cf47c92cb784a0ccd27bb Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 9 Sep 2024 09:12:20 +0100 Subject: [PATCH 18/72] #2829: Add check for capture_nmne --- src/primaite/game/agent/observations/nic_observations.py | 7 +++++-- src/primaite/game/game.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..c5da8767 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,18 +1,21 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import ClassVar, Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A dataclass defining malicious network events to be captured." class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -164,7 +167,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][Port[port].value] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 045b2467..9c0f49af 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -10,6 +10,7 @@ from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations import NICObservation from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent @@ -275,6 +276,7 @@ class PrimaiteGame: links_cfg = network_config.get("links", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] From 4a48a8d0547f00f186617cd4226d3853fd0e2be3 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 9 Sep 2024 10:54:42 +0100 Subject: [PATCH 19/72] #2345: return error if db connection cannot be made --- .../simulator/system/services/web_server/web_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index f9f561df..0df47999 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -119,9 +119,10 @@ class WebServer(Service): if path.startswith("users"): # get data from DatabaseServer # get all users - if self._establish_db_connection(): + if not self._establish_db_connection(): # unable to create a db connection response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + return response if self.db_connection.query("SELECT"): # query succeeded From 3cecf169bafab4b059ee27e3156f365e1bb9f3c9 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 9 Sep 2024 16:30:36 +0100 Subject: [PATCH 20/72] #2829: Update and add nmne tests --- .../agent/observations/nic_observations.py | 3 ++- src/primaite/game/game.py | 2 +- .../observations/test_nic_observations.py | 8 +++++++ .../network/test_capture_nmne.py | 22 +++++++++++++++++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index c5da8767..ed2bb7f9 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -14,8 +14,9 @@ from primaite.simulator.network.transmission.transport_layer import Port class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne - "A dataclass defining malicious network events to be captured." + "A Boolean specifying whether malicious network events should be captured." class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 9afdbea6..64cdf63b 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent -from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.observations import NICObservation +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction, SharedReward from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..ced598f0 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,14 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index debf5b1c..1499df9a 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg) From 82887bdb177258c1d9633b4860833b02c7b640f9 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 10 Sep 2024 10:52:00 +0100 Subject: [PATCH 21/72] #2842: apply PR suggestions --- .../game_layer/actions/test_user_account_actions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py index bb36ce73..2fbf5a8c 100644 --- a/tests/integration_tests/game_layer/actions/test_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -28,7 +28,7 @@ def test_user_account_add_user_action(game_and_agent_fixture): # add admin account action = ( "NODE_ACCOUNTS_ADD_USER", - {"node_id": 0, "username": "soccon_diiz", "password": "nuts", "is_admin": True}, + {"node_id": 0, "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, ) agent.store_action(action) game.step() @@ -39,7 +39,7 @@ def test_user_account_add_user_action(game_and_agent_fixture): # add non admin account action = ( "NODE_ACCOUNTS_ADD_USER", - {"node_id": 0, "username": "mike_rotch", "password": "password", "is_admin": False}, + {"node_id": 0, "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, ) agent.store_action(action) game.step() @@ -53,7 +53,7 @@ def test_user_account_disable_user_action(game_and_agent_fixture): game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") - client_1.user_manager.add_user(username="test", password="icles", is_admin=True) + client_1.user_manager.add_user(username="test", password="password", is_admin=True) assert len(client_1.user_manager.users) == 2 # new user added assert len(client_1.user_manager.admins) == 2 @@ -79,7 +79,7 @@ def test_user_account_change_password_action(game_and_agent_fixture): game, agent = game_and_agent_fixture client_1 = game.simulation.network.get_node_by_hostname("client_1") - client_1.user_manager.add_user(username="test", password="icles", is_admin=True) + client_1.user_manager.add_user(username="test", password="password", is_admin=True) test_user = client_1.user_manager.users.get("test") assert test_user.password == "icles" @@ -87,7 +87,7 @@ def test_user_account_change_password_action(game_and_agent_fixture): # change account password action = ( "NODE_ACCOUNTS_CHANGE_PASSWORD", - {"node_id": 0, "username": "test", "current_password": "icles", "new_password": "2Hard_2_Hack"}, + {"node_id": 0, "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, ) agent.store_action(action) game.step() From 1c6e8b2a95227606ba99a66fb32bf40fe0e1225b Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 10 Sep 2024 11:39:04 +0100 Subject: [PATCH 22/72] #2775 - Removed default ARP rule for routers and added logic when handling ARP traffic --- .../simulator/network/hardware/nodes/network/router.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index ceb91695..bfc90984 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -467,6 +467,14 @@ class AccessControlList(SimComponent): """Check if a packet with the given properties is permitted through the ACL.""" permitted = False rule: ACLRule = None + + # check if the frame is ARP and if ACL rules apply. + if frame.udp: + if frame.is_arp: + permitted = True + rule: ACLRule = None + return permitted, rule + for _rule in self._acl: if not _rule: continue @@ -1257,7 +1265,6 @@ class Router(NetworkNode): Initializes the router's ACL (Access Control List) with default rules, permitting essential protocols like ARP and ICMP, which are necessary for basic network operations and diagnostics. """ - self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) def setup_for_episode(self, episode: int): From 19d6fa2174b0d304ff4abc9a159087383d842c6f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 10:12:07 +0100 Subject: [PATCH 23/72] #2775 - Updated to look neater --- .../network/hardware/nodes/network/router.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index bfc90984..3b267200 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -468,13 +468,6 @@ class AccessControlList(SimComponent): permitted = False rule: ACLRule = None - # check if the frame is ARP and if ACL rules apply. - if frame.udp: - if frame.is_arp: - permitted = True - rule: ACLRule = None - return permitted, rule - for _rule in self._acl: if not _rule: continue @@ -1376,6 +1369,12 @@ class Router(NetworkNode): return False + def subject_to_acl(self, frame: Frame) -> bool: + """Check that frame is subject to ACL rules.""" + if frame.ip.protocol == IPProtocol.UDP and frame.udp.dst_port == Port.ARP: + return False + return True + def receive_frame(self, frame: Frame, from_network_interface: RouterInterface): """ Processes an incoming frame received on one of the router's interfaces. @@ -1389,8 +1388,12 @@ class Router(NetworkNode): if self.operating_state != NodeOperatingState.ON: return - # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + if self.subject_to_acl(frame=frame): + # Check if it's permitted + permitted, rule = self.acl.is_permitted(frame) + else: + permitted = True + rule = None if not permitted: at_port = self._get_port_of_nic(from_network_interface) From a2005df9f0d0d1d3f32051058d5bc3e4b7cfcedf Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 10:56:44 +0100 Subject: [PATCH 24/72] #2775 - Documentation updates --- docs/source/configuration/simulation/nodes/router.rst | 6 +----- docs/source/simulation_components/network/network.rst | 9 +-------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst index ac9d6411..b8741521 100644 --- a/docs/source/configuration/simulation/nodes/router.rst +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -74,7 +74,7 @@ The subnet mask setting for the port. ``acl`` ------- -Sets up the ACL rules for the router. +Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP. e.g. @@ -85,10 +85,6 @@ e.g. ... acl: 1: - action: PERMIT - src_port: ARP - dst_port: ARP - 2: action: PERMIT protocol: ICMP diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 636ffbcc..00781307 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -97,17 +97,10 @@ we'll use the following Network that has a client, server, two switches, and a r network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1]) network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1]) -8. Add ACL rules on the Router to allow ARP and ICMP traffic. +8. Add an ACL rules on the Router to allow ICMP traffic. .. code-block:: python - router_1.acl.add_rule( - action=ACLAction.PERMIT, - src_port=Port.ARP, - dst_port=Port.ARP, - position=22 - ) - router_1.acl.add_rule( action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, From d5f1d0fda184aa4cfdf9f7ae1fb032f7073ed145 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 11:28:27 +0100 Subject: [PATCH 25/72] #2775 - Updated Changelog and bring up to date with dev --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2989247..77b7bb7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - Log observation space data by episode and step. +- ACL's are no longer applied to layer-2 traffic. ## [3.3.0] - 2024-09-04 ### Added From f95501f2a872099f616fd5cabfac75d0ddaf4b18 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 15:12:36 +0100 Subject: [PATCH 26/72] #2775 - Purging of more instances where the ARP acl rule is set and no longer necessary. Added a new test to show ARP is unaffected by ACL rules and actioned review comments --- .../simulation_components/network/network.rst | 2 +- .../network/nodes/wireless_router.rst | 1 - src/primaite/simulator/network/creation.py | 2 -- .../network/hardware/nodes/network/router.py | 3 +++ src/primaite/simulator/network/networks.py | 4 ---- .../network/transmission/data_link_layer.py | 2 +- tests/conftest.py | 3 --- .../game_layer/test_actions.py | 14 ++++++------- .../integration_tests/network/test_routing.py | 1 - .../network/test_wireless_router.py | 1 - tests/integration_tests/system/test_arp.py | 20 ++++++++++++++++++- .../_system/_services/test_terminal.py | 1 - 12 files changed, 31 insertions(+), 23 deletions(-) diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 00781307..b04d6ecf 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -97,7 +97,7 @@ we'll use the following Network that has a client, server, two switches, and a r network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1]) network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1]) -8. Add an ACL rules on the Router to allow ICMP traffic. +8. Add an ACL rule on the Router to allow ICMP traffic. .. code-block:: python diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index c78c8419..02fe73db 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -102,7 +102,6 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) # Configure PC B diff --git a/src/primaite/simulator/network/creation.py b/src/primaite/simulator/network/creation.py index 61a37a90..b801a38e 100644 --- a/src/primaite/simulator/network/creation.py +++ b/src/primaite/simulator/network/creation.py @@ -7,7 +7,6 @@ from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port def num_of_switches_required(num_nodes: int, max_network_interface: int = 24) -> int: @@ -98,7 +97,6 @@ def create_office_lan( default_gateway = IPv4Address(f"192.168.{subnet_base}.1") router = Router(hostname=f"router_{lan_name}", start_up_duration=0) router.power_on() - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) network.add_node(router) router.configure_port(port=1, ip_address=default_gateway, subnet_mask="255.255.255.0") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3b267200..e86b1843 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1388,6 +1388,9 @@ class Router(NetworkNode): if self.operating_state != NodeOperatingState.ON: return + print("£££££££££££££££££££££££££££££") + print(f"Frame received is: {frame}") + if self.subject_to_acl(frame=frame): # Check if it's permitted permitted, rule = self.acl.is_permitted(frame) diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index cb0965eb..ae6476c1 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -79,8 +79,6 @@ def client_server_routed() -> Network: server_1.power_on() network.connect(endpoint_b=server_1.network_interface[1], endpoint_a=switch_1.network_interface[1]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) return network @@ -271,8 +269,6 @@ def arcd_uc2_network() -> Network: security_suite.connect_nic(NIC(ip_address="192.168.10.110", subnet_mask="255.255.255.0")) network.connect(endpoint_b=security_suite.network_interface[2], endpoint_a=switch_2.network_interface[7]) - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) # Allow PostgreSQL requests diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 159eca7f..9d8a0a1c 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -161,7 +161,7 @@ class Frame(BaseModel): """ Checks if the Frame is an ARP (Address Resolution Protocol) packet. - This is determined by checking if the destination port of the TCP header is equal to the ARP port. + This is determined by checking if the destination port of the UDP header is equal to the ARP port. :return: True if the Frame is an ARP packet, otherwise False. """ diff --git a/tests/conftest.py b/tests/conftest.py index 1bbff8f2..e9aeada8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -350,7 +350,6 @@ def install_stuff_to_sim(sim: Simulation): network.connect(endpoint_a=server_2.network_interface[1], endpoint_b=switch_2.network_interface[2]) # 2: Configure base ACL - router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1) router.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) @@ -382,8 +381,6 @@ def install_stuff_to_sim(sim: Simulation): assert acl_rule.src_port == acl_rule.dst_port == Port.DNS elif i == 3: assert acl_rule.src_port == acl_rule.dst_port == Port.HTTP - elif i == 22: - assert acl_rule.src_port == acl_rule.dst_port == Port.ARP elif i == 23: assert acl_rule.protocol == IPProtocol.ICMP elif i == 24: diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index a1005f34..ecd21a03 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -115,7 +115,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox server_1 = game.simulation.network.get_node_by_hostname("server_1") server_2 = game.simulation.network.get_node_by_hostname("server_2") router = game.simulation.network.get_node_by_hostname("router") - assert router.acl.num_rules == 4 + assert router.acl.num_rules == 3 assert client_1.ping("10.0.2.3") # client_1 can ping server_2 assert server_2.ping("10.0.1.2") # server_2 can ping client_1 @@ -138,8 +138,8 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox agent.store_action(action) game.step() - # 3: Check that the ACL now has 5 rules, and that client 1 cannot ping server 2 - assert router.acl.num_rules == 5 + # 3: Check that the ACL now has 4 rules, and that client 1 cannot ping server 2 + assert router.acl.num_rules == 4 assert not client_1.ping("10.0.2.3") # Cannot ping server_2 assert client_1.ping("10.0.2.2") # Can ping server_1 assert not server_2.ping( @@ -165,8 +165,8 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox agent.store_action(action) game.step() - # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 - assert router.acl.num_rules == 6 + # 5: Check that the ACL now has 5 rules, but that server_1 can still ping server_2 + assert router.acl.num_rules == 5 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -195,8 +195,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P agent.store_action(action) game.step() - # 3: Check that the ACL now has 3 rules, and that client 1 cannot access example.com - assert router.acl.num_rules == 3 + # 3: Check that the ACL now has 2 rules, and that client 1 cannot access example.com + assert router.acl.num_rules == 2 assert not browser.get_webpage() client_1.software_manager.software.get("DNSClient").dns_cache.clear() assert client_1.ping("10.0.2.2") # pinging still works because ICMP is allowed diff --git a/tests/integration_tests/network/test_routing.py b/tests/integration_tests/network/test_routing.py index 62b58cbd..e234b4e5 100644 --- a/tests/integration_tests/network/test_routing.py +++ b/tests/integration_tests/network/test_routing.py @@ -73,7 +73,6 @@ def multi_hop_network() -> Network: router_1.enable_port(2) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) # Configure PC B diff --git a/tests/integration_tests/network/test_wireless_router.py b/tests/integration_tests/network/test_wireless_router.py index 733de6f6..9a22208b 100644 --- a/tests/integration_tests/network/test_wireless_router.py +++ b/tests/integration_tests/network/test_wireless_router.py @@ -37,7 +37,6 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) # Configure PC B diff --git a/tests/integration_tests/system/test_arp.py b/tests/integration_tests/system/test_arp.py index be8656aa..6c7e853a 100644 --- a/tests/integration_tests/system/test_arp.py +++ b/tests/integration_tests/system/test_arp.py @@ -1,5 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.network.router import RouterARP +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.services.arp.arp import ARP from tests.integration_tests.network.test_routing import multi_hop_network @@ -48,3 +50,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network): actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address) assert actual_result == expected_result + + +def test_arp_not_affected_by_acl(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") + + # Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic + # as it operates a different layer within the network. + router_1.acl.add_rule(action=ACLAction.DENY, src_port=Port.ARP, dst_port=Port.ARP, position=23) + + pc_a_arp: ARP = pc_a.software_manager.arp + + expected_result = router_1.network_interface[2].mac_address + actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address) + + assert actual_result == expected_result diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 41858b90..3c3daa61 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -77,7 +77,6 @@ def wireless_wan_network(): network.connect(pc_a.network_interface[1], router_1.network_interface[2]) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) # add ACL rule to allow SSH traffic From 85863b1972516a9c40b441e23ff2865bd82fe437 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 15:36:51 +0100 Subject: [PATCH 27/72] #2775 - Removed a print statement committed in error and updated the checks done in subject_to_acl following review --- .../simulator/network/hardware/nodes/network/router.py | 5 +---- .../simulator/network/transmission/data_link_layer.py | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index e86b1843..8cdf3f86 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -1371,7 +1371,7 @@ class Router(NetworkNode): def subject_to_acl(self, frame: Frame) -> bool: """Check that frame is subject to ACL rules.""" - if frame.ip.protocol == IPProtocol.UDP and frame.udp.dst_port == Port.ARP: + if frame.ip.protocol == IPProtocol.UDP and frame.is_arp: return False return True @@ -1388,9 +1388,6 @@ class Router(NetworkNode): if self.operating_state != NodeOperatingState.ON: return - print("£££££££££££££££££££££££££££££") - print(f"Frame received is: {frame}") - if self.subject_to_acl(frame=frame): # Check if it's permitted permitted, rule = self.acl.is_permitted(frame) diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index 9d8a0a1c..86a6038b 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -161,11 +161,11 @@ class Frame(BaseModel): """ Checks if the Frame is an ARP (Address Resolution Protocol) packet. - This is determined by checking if the destination port of the UDP header is equal to the ARP port. + This is determined by checking if the destination and source port of the UDP header is equal to the ARP port. :return: True if the Frame is an ARP packet, otherwise False. """ - return self.udp.dst_port == Port.ARP + return self.udp.dst_port == Port.ARP and self.udp.src_port == Port.ARP @property def is_icmp(self) -> bool: From f908f9b23e43f910ee8b1e732ca546e2d3b954ca Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 11 Sep 2024 15:50:14 +0100 Subject: [PATCH 28/72] #2775 - Actioning review comments --- tests/integration_tests/game_layer/test_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index ecd21a03..a9231632 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -106,7 +106,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox """ Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation. - The ACL starts off with 4 rules, and we add a rule, and check that the ACL now has 5 rules. + The ACL starts off with 3 rules, and we add a rule, and check that the ACL now has 4 rules. """ game, agent = game_and_agent From 8bd20275d085fc79b016e690ebabff0b4d52008f Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 12 Sep 2024 10:01:12 +0100 Subject: [PATCH 29/72] #2842: fix test --- .../game_layer/actions/test_user_account_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py index 2fbf5a8c..f97716c6 100644 --- a/tests/integration_tests/game_layer/actions/test_user_account_actions.py +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -82,7 +82,7 @@ def test_user_account_change_password_action(game_and_agent_fixture): client_1.user_manager.add_user(username="test", password="password", is_admin=True) test_user = client_1.user_manager.users.get("test") - assert test_user.password == "icles" + assert test_user.password == "password" # change account password action = ( From 7c26ca9d79d14bc529368f455909a097c8a4614c Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 12 Sep 2024 16:07:14 +0100 Subject: [PATCH 30/72] #2864: add configuration for services_requires_scan and applications_requires_scan --- .../agent/observations/host_observations.py | 24 +++- .../agent/observations/node_observations.py | 12 +- .../observations/software_observation.py | 33 ++++-- .../_game/_agent/test_observations.py | 111 +++++++++++++++++- 4 files changed, 169 insertions(+), 11 deletions(-) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..3371a99c 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -52,6 +52,14 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ + services_requires_scan: Optional[bool] = None + """ + If True, services must be scanned to update the health state. If False, true state is always shown. + """ + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ include_users: Optional[bool] = True """If True, report user session information.""" @@ -71,6 +79,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + services_requires_scan: bool, + applications_requires_scan: bool, include_users: bool, ) -> None: """ @@ -106,6 +116,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param services_requires_scan: If True, services must be scanned to update the health state. + If False, the true state is always shown. + :type services_requires_scan: bool + :param applications_requires_scan: If True, applications must be scanned to update the health state. + If False, the true state is always shown. + :type applications_requires_scan: bool :param include_users: If True, report user session information. :type include_users: bool """ @@ -119,7 +135,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) + self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan)) while len(self.services) > num_services: truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" @@ -127,7 +143,9 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.applications: List[ApplicationObservation] = applications while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) + self.applications.append( + ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan) + ) while len(self.applications) > num_applications: truncated_application = self.applications.pop() msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" @@ -293,5 +311,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + services_requires_scan=config.services_requires_scan, + applications_requires_scan=config.applications_requires_scan, include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..85de5396 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -45,7 +45,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"): include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" file_system_requires_scan: bool = True - """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + """If True, the folder must be scanned to update the health state. If False, the true state is always shown.""" + services_requires_scan: bool = True + """If True, the services must be scanned to update the health state. + If False, the true state is always shown.""" + applications_requires_scan: bool = True + """If True, the applications must be scanned to update the health state. + If False, the true state is always shown.""" include_users: Optional[bool] = True """If True, report user session information.""" num_ports: Optional[int] = None @@ -193,6 +199,10 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.services_requires_scan is None: + host_config.services_requires_scan = config.services_requires_scan + if host_config.applications_requires_scan is None: + host_config.applications_requires_scan = config.applications_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 15cd2447..2075ce43 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): service_name: str """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + services_requires_scan: Optional[bool] = None + """If True, services must be scanned to update the health state. If False, true state is always shown.""" + + def __init__(self, where: WhereType, services_requires_scan: bool) -> None: """ Initialise a service observation instance. @@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): :type where: WhereType """ self.where = where + self.services_requires_scan = services_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0} def observe(self, state: Dict) -> ObsType: @@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): return self.default_observation return { "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], + "health_status": service_state["health_state_visible"] + if self.services_requires_scan + else service_state["health_state_actual"], } @property @@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): :return: Constructed service observation instance. :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config.service_name]) + return cls( + where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan + ) class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): @@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + + def __init__(self, where: WhereType, applications_requires_scan: bool) -> None: """ Initialise an application observation instance. @@ -92,6 +105,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :type where: WhereType """ self.where = where + self.applications_requires_scan = applications_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} # TODO: allow these to be configured in yaml @@ -128,7 +142,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): return self.default_observation return { "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], + "health_status": application_state["health_state_visible"] + if self.applications_requires_scan + else application_state["health_state_actual"], "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @@ -161,4 +177,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls( + where=parent_where + ["applications", config.application_name], + applications_requires_scan=config.applications_requires_scan, + ) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 7f590685..583b9cbd 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -4,7 +4,7 @@ from typing import List import pytest import yaml -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations import ObservationManager, ServiceObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.host_observations import HostObservation @@ -130,3 +130,112 @@ class TestFileSystemRequiresScan: [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=False ) assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3 + + +class TestServiceRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("services_requires_scan: true", True), + ("services_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set FileSystemRequiresScan to True.""" + obs_cfg_yaml = f""" + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + applications: + - application_name: WebBrowser + - hostname: client_2 + num_services: 1 + num_applications: 1 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for host in hosts: + services: List[ServiceObservation] = host.services + for service in services: + assert service.services_requires_scan == expected_val # Make sure services require scan by default + + def test_services_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1} + + obs_requiring_scan = ServiceObservation([], services_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value From 1f937a4c961ae77fa11c57136d3366bd8b073439 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 12 Sep 2024 18:54:18 +0100 Subject: [PATCH 31/72] #2864: config not being passed correctly --- .../agent/observations/host_observations.py | 4 ++++ .../observations/test_node_observations.py | 2 ++ .../test_software_observations.py | 8 ++++++-- .../_game/_agent/test_observations.py | 20 +++++++++++-------- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 3371a99c..c05b493a 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -281,6 +281,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): folder_config.file_system_requires_scan = config.file_system_requires_scan for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + for service_config in config.services: + service_config.services_requires_scan = config.services_requires_scan + for application_config in config.applications: + application_config.application_config_requires_scan = config.application_config_requires_scan services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 69d9f106..9d60823b 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,8 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + services_requires_scan=True, + applications_requires_scan=True, include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 998aa755..ab9f6e9c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,9 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("NTPServer") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + service_obs = ServiceObservation( + where=["network", "nodes", pc.hostname, "services", "NTPServer"], services_requires_scan=True + ) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +56,9 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + app_obs = ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], applications_requires_scan=True + ) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 583b9cbd..912b672e 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json from typing import List import pytest @@ -142,7 +143,7 @@ class TestServiceRequiresScan: ), ) def test_obs_config(self, yaml_option_string, expected_val): - """Check that the default behaviour is to set FileSystemRequiresScan to True.""" + """Check that the default behaviour is to set service_requires_scan to True.""" obs_cfg_yaml = f""" type: CUSTOM options: @@ -155,19 +156,20 @@ class TestServiceRequiresScan: - hostname: web_server services: - service_name: WebServer + - service_name: DNSClient - hostname: database_server folders: - folder_name: database files: - file_name: database.db - hostname: backup_server + services: + - service_name: FTPServer - hostname: security_suite - hostname: client_1 - applications: - - application_name: WebBrowser - hostname: client_2 - num_services: 1 - num_applications: 1 + num_services: 3 + num_applications: 0 num_folders: 1 num_files: 1 num_nics: 2 @@ -226,10 +228,12 @@ class TestServiceRequiresScan: manager = ObservationManager.from_config(cfg) hosts: List[HostObservation] = manager.obs.components["NODES"].hosts - for host in hosts: + for i, host in enumerate(hosts): services: List[ServiceObservation] = host.services - for service in services: - assert service.services_requires_scan == expected_val # Make sure services require scan by default + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure services require scan by default def test_services_requires_scan(self): state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1} From f1ff1f13cf4fc9e4fc8760a51b4d029f1307c8af Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 13 Sep 2024 09:08:44 +0100 Subject: [PATCH 32/72] #2864: added applications_requires_scan test --- .../agent/observations/host_observations.py | 2 +- .../_game/_agent/test_observations.py | 116 +++++++++++++++++- 2 files changed, 115 insertions(+), 3 deletions(-) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index c05b493a..da054eda 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -284,7 +284,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): for service_config in config.services: service_config.services_requires_scan = config.services_requires_scan for application_config in config.applications: - application_config.application_config_requires_scan = config.application_config_requires_scan + application_config.applications_requires_scan = config.applications_requires_scan services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 912b672e..935bbdcf 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -5,7 +5,7 @@ from typing import List import pytest import yaml -from primaite.game.agent.observations import ObservationManager, ServiceObservation +from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.host_observations import HostObservation @@ -133,7 +133,7 @@ class TestFileSystemRequiresScan: assert obs_not_requiring_scan.observe(folder_state)["health_status"] == 3 -class TestServiceRequiresScan: +class TestServicesRequiresScan: @pytest.mark.parametrize( ("yaml_option_string", "expected_val"), ( @@ -243,3 +243,115 @@ class TestServiceRequiresScan: obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False) assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value + + +class TestApplicationsRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("applications_requires_scan: true", True), + ("applications_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set applications_requires_scan to True.""" + obs_cfg_yaml = f""" + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + applications: + - application_name: WebBrowser + - hostname: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + num_services: 0 + num_applications: 3 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: "NONE" + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure applications require scan by default + + def test_applications_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1} + + obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value From 454789f49461ff285b22d7a5bfbe4890a4c5f335 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 13 Sep 2024 09:34:09 +0100 Subject: [PATCH 33/72] #2864: add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77b7bb7d..56f0c038 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Log observation space data by episode and step. - ACL's are no longer applied to layer-2 traffic. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). ## [3.3.0] - 2024-09-04 ### Added From d8c85058edc2b80e9f9afcede098c9d133804d2f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 10:32:09 +0100 Subject: [PATCH 34/72] #2456 - Minor change to arp.show() to include port number --- src/primaite/simulator/system/services/arp/arp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index efadf189..9314bea7 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -47,7 +47,7 @@ class ARP(Service): :param markdown: If True, format the output as Markdown. Otherwise, use plain text. """ - table = PrettyTable(["IP Address", "MAC Address", "Via"]) + table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" @@ -58,6 +58,7 @@ class ARP(Service): str(ip), arp.mac_address, self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address, + self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num, ] ) print(table) From 94b30909ee62624181bbaa7187948f41f17b8164 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 10:56:12 +0100 Subject: [PATCH 35/72] #2456 - Updated Changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77b7bb7d..71341a17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - Log observation space data by episode and step. + +### Changed - ACL's are no longer applied to layer-2 traffic. +- ARP .show() method will no include the port number associated with each entry. + ## [3.3.0] - 2024-09-04 ### Added From 9a2fb2a0846cdc8958d7cad03fdbfaa6277e31bc Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 13 Sep 2024 11:11:58 +0100 Subject: [PATCH 36/72] #2880: fix action shape for num_ports + test --- src/primaite/game/agent/actions.py | 4 ++-- .../game_layer/test_action_shapes.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/integration_tests/game_layer/test_action_shapes.py diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index a299788e..c864f75f 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -877,7 +877,7 @@ class FirewallACLRemoveRuleAction(AbstractAction): """Action which removes a rule from a firewall port's ACL.""" def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for RouterACLRemoveRuleAction. + """Init method for FirewallACLRemoveRuleAction. :param manager: Reference to the ActionManager which created this action. :type manager: ActionManager @@ -1524,7 +1524,7 @@ class ActionManager: "num_nics": max_nics_per_node, "num_acl_rules": max_acl_rules, "num_protocols": len(self.protocols), - "num_ports": len(self.protocols), + "num_ports": len(self.ports), "num_ips": len(self.ip_address_list), "max_acl_rules": max_acl_rules, "max_nics_per_node": max_nics_per_node, diff --git a/tests/integration_tests/game_layer/test_action_shapes.py b/tests/integration_tests/game_layer/test_action_shapes.py new file mode 100644 index 00000000..48500d8f --- /dev/null +++ b/tests/integration_tests/game_layer/test_action_shapes.py @@ -0,0 +1,21 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import Tuple + +from primaite.game.agent.interface import ProxyAgent +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +FIREWALL_ACTIONS_NETWORK = TEST_ASSETS_ROOT / "configs/firewall_actions_network.yaml" + + +def test_router_acl_add_rule_action_shape(game_and_agent: Tuple[PrimaiteGame, ProxyAgent]): + """Test to check ROUTER_ADD_ACL_RULE has the expected action shape.""" + game, agent = game_and_agent + + # assert that the shape of the actions is correct + router_acl_add_rule_action = agent.action_manager.actions.get("ROUTER_ACL_ADDRULE") + assert router_acl_add_rule_action.shape.get("source_ip_id") == len(agent.action_manager.ip_address_list) + assert router_acl_add_rule_action.shape.get("dest_ip_id") == len(agent.action_manager.ip_address_list) + assert router_acl_add_rule_action.shape.get("source_port_id") == len(agent.action_manager.ports) + assert router_acl_add_rule_action.shape.get("dest_port_id") == len(agent.action_manager.ports) + assert router_acl_add_rule_action.shape.get("protocol_id") == len(agent.action_manager.protocols) From 17035be0284f1789e026d4cf8328cc97b2035c8b Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 11:13:55 +0100 Subject: [PATCH 37/72] #2456 - Actioning review comment --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71341a17..53b29e85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - ACL's are no longer applied to layer-2 traffic. -- ARP .show() method will no include the port number associated with each entry. +- ARP .show() method will now include the port number associated with each entry. ## [3.3.0] - 2024-09-04 From c924b9ea46dcd2a701fdf6ba93a3733b72c09a99 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 11:54:17 +0100 Subject: [PATCH 38/72] #2871 - Initial commit of a show_history() function in AbstractAgent --- src/primaite/game/agent/interface.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index d5165a71..404c2bfe 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType +from prettytable import PrettyTable from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager @@ -126,6 +127,27 @@ class AbstractAgent(ABC): self.history: List[AgentHistoryItem] = [] self.logger = AgentLog(agent_name) + def show_history(self): + """ + Print an agent action provided it's not the DONOTHING action. + + :param agent_name: Name of agent (str). + """ + table = PrettyTable() + table.field_names = ["Step", "Action", "Node", "Application", "Response"] + print(f"Actions for '{self.agent_name}':") + for item in self.history: + if item.action != "DONOTHING": + node, application = "unknown", "unknown" + if (node_id := item.parameters.get("node_id")) is not None: + node = self.action_manager.node_names[node_id] + if (application_id := item.parameters.get("application_id")) is not None: + application = self.action_manager.application_names[node_id][application_id] + if (application_name := item.parameters.get("application_name")) is not None: + application = application_name + table.add_row([item.timestep, item.action, node, application, item.response.status]) + print(table) + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. From cd8fc6d42d153b28b5cd731a4fe19239bfcc327d Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 13 Sep 2024 12:10:49 +0100 Subject: [PATCH 39/72] #2879: Handle generate_seed_value option --- src/primaite/session/environment.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index c66663e3..ac9415ac 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -26,14 +26,25 @@ except ModuleNotFoundError: _LOGGER.debug("Torch not available for importing") -def set_random_seed(seed: int) -> Union[None, int]: +def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: """ Set random number generators. + If seed is None or -1 and generate_seed_value is True randomly generate a + seed value. + If seed is > -1 and generate_seed_value is True ignore the latter and use + the provide seed value. + :param seed: int + :param generate_seed_value: bool + :return: None or the int representing the seed used. """ if seed is None or seed == -1: - return None + if generate_seed_value: + rng = np.random.default_rng() + seed = int(rng.integers(low=0, high=2**63)) + else: + return None elif seed < -1: raise ValueError("Invalid random number seed") # Seed python RNG @@ -65,7 +76,8 @@ class PrimaiteGymEnv(gymnasium.Env): """Object that returns a config corresponding to the current episode.""" self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" - self.seed = set_random_seed(self.seed) + self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value") + self.seed = set_random_seed(self.seed, self.generate_seed_value) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) From 6ebe50c331725c5059f269a59d87bd1dcd4077b3 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 13 Sep 2024 12:58:37 +0100 Subject: [PATCH 40/72] #2879: Reduce max seed value to comply with python random seed limit --- src/primaite/session/environment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index ac9415ac..0fd21b9f 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -42,7 +42,8 @@ def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: if seed is None or seed == -1: if generate_seed_value: rng = np.random.default_rng() - seed = int(rng.integers(low=0, high=2**63)) + # 2**32-1 is highest value for python RNG seed. + seed = int(rng.integers(low=0, high=2**32-1)) else: return None elif seed < -1: From 08fcf1df19fc811bf9a24aee04d8c5c3239f9678 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 13 Sep 2024 12:59:41 +0100 Subject: [PATCH 41/72] #2879: Add generate_seed_value to global options. --- src/primaite/game/game.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 123b6ddd..0e7b8c23 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -80,6 +80,8 @@ class PrimaiteGameOptions(BaseModel): seed: int = None """Random number seed for RNGs.""" + generate_seed_value: bool = False + """Internally generated seed value.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[str] From f2a0eeaca23159da9caa0cd9e55e81f5aaac6875 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 14:11:13 +0100 Subject: [PATCH 42/72] #2871 - Updated show_history() method to use boolean 'include_nothing' for whether to include DONOTHING actions --- src/primaite/game/agent/interface.py | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 404c2bfe..0ec44d22 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -127,25 +127,34 @@ class AbstractAgent(ABC): self.history: List[AgentHistoryItem] = [] self.logger = AgentLog(agent_name) - def show_history(self): + def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: + """Update the given table with information from given AgentHistoryItem.""" + node, application = "unknown", "unknown" + if (node_id := item.parameters.get("node_id")) is not None: + node = self.action_manager.node_names[node_id] + if (application_id := item.parameters.get("application_id")) is not None: + application = self.action_manager.application_names[node_id][application_id] + if (application_name := item.parameters.get("application_name")) is not None: + application = application_name + table.add_row([item.timestep, item.action, node, application, item.response.status]) + return table + + def show_history(self, include_nothing: bool = False): """ Print an agent action provided it's not the DONOTHING action. - :param agent_name: Name of agent (str). + :param include_nothing: boolean for including DONOTHING actions. Default False. """ table = PrettyTable() table.field_names = ["Step", "Action", "Node", "Application", "Response"] print(f"Actions for '{self.agent_name}':") for item in self.history: - if item.action != "DONOTHING": - node, application = "unknown", "unknown" - if (node_id := item.parameters.get("node_id")) is not None: - node = self.action_manager.node_names[node_id] - if (application_id := item.parameters.get("application_id")) is not None: - application = self.action_manager.application_names[node_id][application_id] - if (application_name := item.parameters.get("application_name")) is not None: - application = application_name - table.add_row([item.timestep, item.action, node, application, item.response.status]) + if item.action == "DONOTHING": + if include_nothing: + table = self.add_agent_action(item=item, table=table) + else: + pass + self.add_agent_action(item=item, table=table) print(table) def update_observation(self, state: Dict) -> ObsType: From 01a2c834ce3c8ff23c90ff098ef2cce04bdd5bab Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Fri, 13 Sep 2024 14:53:15 +0100 Subject: [PATCH 43/72] #2879: Write seed value to log file. --- src/primaite/session/environment.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 0fd21b9f..9054106e 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -62,6 +62,13 @@ def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: return seed +def log_seed_value(seed: int): + """Log the selected seed value to file.""" + path = SIM_OUTPUT.path / "seed.log" + with open(path, "w") as file: + file.write(f"Seed value = {seed}") + + class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -92,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env): _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + log_seed_value(self.seed) + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. From 5006e41546d37cabb0a505fe0c9e3346dcaebf89 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Fri, 13 Sep 2024 15:47:59 +0100 Subject: [PATCH 44/72] #2871 - Updated the show_history() function to receive a list of actions to ignore when printing the history. Defaults to ignoring DONOTHING actions --- src/primaite/game/agent/interface.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 0ec44d22..6609dd03 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -139,22 +139,23 @@ class AbstractAgent(ABC): table.add_row([item.timestep, item.action, node, application, item.response.status]) return table - def show_history(self, include_nothing: bool = False): + def show_history(self, ignored_actions: Optional[list] = None): """ Print an agent action provided it's not the DONOTHING action. - :param include_nothing: boolean for including DONOTHING actions. Default False. + :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. + If not provided, defaults to ignore DONOTHING actions. """ + if not ignored_actions: + ignored_actions = ["DONOTHING"] table = PrettyTable() table.field_names = ["Step", "Action", "Node", "Application", "Response"] print(f"Actions for '{self.agent_name}':") for item in self.history: - if item.action == "DONOTHING": - if include_nothing: - table = self.add_agent_action(item=item, table=table) - else: - pass - self.add_agent_action(item=item, table=table) + if item.action in ignored_actions: + pass + else: + table = self.add_agent_action(item=item, table=table) print(table) def update_observation(self, state: Dict) -> ObsType: From e0a10928343c650b986da8aa8cd6207786448e0f Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 16 Sep 2024 09:04:17 +0100 Subject: [PATCH 45/72] #2879: Pre-commit fix. --- src/primaite/session/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 9054106e..07635b70 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -43,7 +43,7 @@ def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: if generate_seed_value: rng = np.random.default_rng() # 2**32-1 is highest value for python RNG seed. - seed = int(rng.integers(low=0, high=2**32-1)) + seed = int(rng.integers(low=0, high=2**32 - 1)) else: return None elif seed < -1: From 215ceaa6e8b5977b231d226715b70d8e88df7f14 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 16 Sep 2024 10:08:45 +0100 Subject: [PATCH 46/72] #2879: Fix call to set_random_seed() in reset(). --- src/primaite/session/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 07635b70..db5425e3 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -168,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if seed is not None: - set_random_seed(seed) + set_random_seed(seed, self.generate_seed_value) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: From f3ca9c55c90fe05b2e43c2c167767037578a0fb7 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Mon, 16 Sep 2024 16:38:19 +0100 Subject: [PATCH 47/72] #2879: Update tests --- .../game_layer/test_RNG_seed.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 0c6d567d..508f35e6 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -7,6 +7,7 @@ import yaml from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import AgentHistoryItem from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator import SIM_OUTPUT @pytest.fixture() @@ -33,6 +34,11 @@ def test_rng_seed_set(create_env): assert a == b + # Check that seed log file was created. + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + assert file + def test_rng_seed_unset(create_env): """Test with no RNG seed.""" @@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env): b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"] assert a != b + + +def test_for_generated_seed(): + """ + Show that setting generate_seed_value to true producess a valid seed. + """ + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + cfg["game"]["generate_seed_value"] = True + PrimaiteGymEnv(env_config=cfg) + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + data = file.read() + + assert data.split(" ")[3] != None From 078b89856535b0071c76921612b6758f6d48782c Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Tue, 17 Sep 2024 09:30:14 +0100 Subject: [PATCH 48/72] #2879: Update changelog. --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77b7bb7d..a9f6c891 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Log observation space data by episode and step. - ACL's are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. ## [3.3.0] - 2024-09-04 ### Added From 5d7935cde083d662389198b8345fc9194e8351be Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Sep 2024 09:39:32 +0100 Subject: [PATCH 49/72] #2871 - Changes to notebooks following updates to action history --- .../Command-&-Control-E2E-Demonstration.ipynb | 12 +++++- .../Data-Manipulation-E2E-Demonstration.ipynb | 11 ++++- .../Getting-Information-Out-Of-PrimAITE.ipynb | 40 ++++++++++++++++++- 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index b6b13f28..a0599ee4 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1800,6 +1800,16 @@ "\n", "display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "env.game.agents[\"CustomC2Agent\"].show_history()" + ] } ], "metadata": { @@ -1818,7 +1828,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 0460f771..c1b959f5 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -675,6 +675,15 @@ " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.game.agents[\"data_manipulation_attacker\"].show_history(ignored_actions=[\"\"])" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -708,7 +717,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index a832f3cc..e4009822 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -144,6 +144,44 @@ "PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n", "PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Viewing Agent history\n", + "\n", + "It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run the training session to generate some resultant data.\n", + "for i in range(100):\n", + " env.step(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attacker = env.game.agents[\"data_manipulation_attacker\"]\n", + "\n", + "attacker.show_history()" + ] } ], "metadata": { @@ -162,7 +200,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.11" } }, "nbformat": 4, From c8f6459af6022f2536580f968c04e9d32b15e596 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Sep 2024 10:09:10 +0100 Subject: [PATCH 50/72] #2871 - Changelog and documentation updates, corrected changes in Data manipulation demo notebook --- CHANGELOG.md | 1 + docs/source/configuration/agents.rst | 1 + .../notebooks/Data-Manipulation-E2E-Demonstration.ipynb | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f1ec29..b7f8a26e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted. ### Changed - ACL's are no longer applied to layer-2 traffic. diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index dece94c5..0bc586e8 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -177,3 +177,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef ----------------- Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation. +A log of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index c1b959f5..13533097 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -681,7 +681,7 @@ "metadata": {}, "outputs": [], "source": [ - "env.game.agents[\"data_manipulation_attacker\"].show_history(ignored_actions=[\"\"])" + "env.game.agents[\"data_manipulation_attacker\"].show_history()" ] }, { From ccb91869c4e7b62e5772c09de496c5cc96b7d35a Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Sep 2024 10:17:18 +0100 Subject: [PATCH 51/72] #2871 - Minor wording change to description in agents.rst --- docs/source/configuration/agents.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index 0bc586e8..74571cf2 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -177,4 +177,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef ----------------- Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation. -A log of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. +A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. From 3a5b75239d64c6febe35ac4bae227e8c804a8f01 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Sep 2024 12:05:40 +0100 Subject: [PATCH 52/72] #2871 - Typo in Changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7f8a26e..b81e256b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [3.3.0] - 2024-09-04 +## [3.4.0] ### Added - Log observation space data by episode and step. From 4391d7cdd559619a1bbb9c4c8b25936bad2d3290 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Tue, 17 Sep 2024 12:19:35 +0100 Subject: [PATCH 53/72] #2445: added the ability to pass the game options thresholds to observations so that relevant observation items can retrieve the thresholds from config --- .../observations/file_system_observations.py | 50 +++++++++++++++++-- .../agent/observations/host_observations.py | 17 ++++++- .../agent/observations/nic_observations.py | 35 ++++++++++--- .../agent/observations/node_observations.py | 6 +++ .../agent/observations/observation_manager.py | 12 +++-- .../game/agent/observations/observations.py | 36 ++++++++++++- .../observations/software_observation.py | 30 +++++++---- src/primaite/game/game.py | 2 +- .../configs/basic_switched_network.yaml | 25 +++++++++- .../test_game_options_config.py | 41 ++++++++++++++- .../test_file_system_observations.py | 32 ++++++++++++ .../observations/test_nic_observations.py | 25 ++++------ .../test_software_observations.py | 27 ++++++++++ 13 files changed, 290 insertions(+), 48 deletions(-) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 1c73d026..fe959c9f 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, identifier="FILE"): file_system_requires_scan: Optional[bool] = None """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: + def __init__( + self, + where: WhereType, + include_num_access: bool, + file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a file observation instance. @@ -48,10 +54,22 @@ class FileObservation(AbstractObservation, identifier="FILE"): if self.include_num_access: self.default_observation["num_access"] = 0 - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("file_access") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ], + threshold_identifier="file_access", + ): + self.low_threshold = thresholds.get("file_access")["low"] + self.med_threshold = thresholds.get("file_access")["medium"] + self.high_threshold = thresholds.get("file_access")["high"] def _categorise_num_access(self, num_access: int) -> int: """ @@ -122,6 +140,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) @@ -149,6 +168,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): num_files: int, include_num_access: bool, file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, ) -> None: """ Initialise a folder observation instance. @@ -170,6 +190,23 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self.file_system_requires_scan: bool = file_system_requires_scan + if thresholds.get("file_access") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ], + threshold_identifier="file_access", + ): + self.low_threshold = thresholds.get("file_access")["low"] + self.med_threshold = thresholds.get("file_access")["medium"] + self.high_threshold = thresholds.get("file_access")["high"] + self.files: List[FileObservation] = files while len(self.files) < num_files: self.files.append( @@ -177,6 +214,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): where=None, include_num_access=include_num_access, file_system_requires_scan=self.file_system_requires_scan, + thresholds=thresholds, ) ) while len(self.files) > num_files: @@ -248,6 +286,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): for file_config in config.files: file_config.include_num_access = config.include_num_access file_config.file_system_requires_scan = config.file_system_requires_scan + file_config.thresholds = config.thresholds files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls( @@ -256,4 +295,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): num_files=config.num_files, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 4419ccc7..fa7ceae5 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -151,7 +151,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) + self.nics.append( + NICObservation( + where=None, + include_nmne=include_nmne, + monitored_traffic=monitored_traffic, + ) + ) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -257,12 +263,16 @@ class HostObservation(AbstractObservation, identifier="HOST"): where = parent_where + [config.hostname] # Pass down shared/common config items + for app_config in config.applications: + app_config.thresholds = config.thresholds for folder_config in config.folders: folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files folder_config.file_system_requires_scan = config.file_system_requires_scan + folder_config.thresholds = config.thresholds for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + nic_config.thresholds = config.thresholds services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] @@ -273,7 +283,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): count = 1 while len(nics) < config.num_nics: nic_config = NICObservation.ConfigSchema( - nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + nic_num=count, + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 002ee4da..48fa11dc 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -24,7 +24,13 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): monitored_traffic: Optional[Dict] = None """A dict containing which traffic types are to be included in the observation.""" - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + monitored_traffic: Optional[Dict] = None, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a network interface observation instance. @@ -44,10 +50,22 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.nmne_inbound_last_step: int = 0 self.nmne_outbound_last_step: int = 0 - # TODO: allow these to be configured in yaml - self.high_nmne_threshold = 10 - self.med_nmne_threshold = 5 - self.low_nmne_threshold = 0 + if thresholds.get("nmne") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("nmne")["low"], + thresholds.get("nmne")["medium"], + thresholds.get("nmne")["high"], + ], + threshold_identifier="nmne", + ): + self.low_threshold = thresholds.get("nmne")["low"] + self.med_threshold = thresholds.get("nmne")["medium"] + self.high_threshold = thresholds.get("nmne")["high"] self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -86,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): :param nmne_count: Number of MNEs detected. :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. """ - if nmne_count > self.high_nmne_threshold: + if nmne_count > self.high_threshold: return 3 - elif nmne_count > self.med_nmne_threshold: + elif nmne_count > self.med_threshold: return 2 - elif nmne_count > self.low_nmne_threshold: + elif nmne_count > self.low_threshold: return 1 return 0 @@ -224,6 +242,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index e263cadb..91bf402e 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -195,6 +195,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): host_config.file_system_requires_scan = config.file_system_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users + if host_config.thresholds is None: + host_config.thresholds = config.thresholds for router_config in config.routers: if router_config.num_ports is None: @@ -211,6 +213,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users + if router_config.thresholds is None: + router_config.thresholds = config.thresholds for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -225,6 +229,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users + if firewall_config.thresholds is None: + firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 9b20fdcb..cc32918c 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -113,7 +113,9 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config( + config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds) + ) instances[component.label] = obs_instance return cls(components=instances) @@ -176,7 +178,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Optional[Dict]) -> "ObservationManager": + def from_config(cls, config: Optional[Dict], thresholds: Optional[Dict] = {}) -> "ObservationManager": """ Create observation space from a config. @@ -187,11 +189,15 @@ class ObservationManager: AbstractObservation options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict + :param thresholds: Dictionary containing the observation thresholds. + :type thresholds: Optional[Dict] """ if config is None: return cls(NullObservation()) obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) + observation = obs_class.from_config( + config=obs_class.ConfigSchema(**config["options"], thresholds=thresholds), + ) obs_manager = cls(observation) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index a9663c56..0b209f52 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from gymnasium import spaces from gymnasium.core import ObsType @@ -19,6 +19,9 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" + thresholds: Optional[Dict] = None + """A dict containing the observation thresholds.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -67,3 +70,34 @@ class AbstractObservation(ABC): def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() + + def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool: + """ + Method that checks if the thresholds are non overlapping and in the correct (ascending) order. + + Pass in the thresholds from low to high e.g. + thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold] + + Throws an error if the threshold is not valid + + :param: thresholds: List of thresholds in ascending order. + :type: List[int] + :param: threshold_identifier: The name of the threshold option. + :type: Optional[str] + + :returns: bool + """ + if thresholds is None or len(thresholds) < 2: + raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}") + for idx in range(1, len(thresholds)): + if not isinstance(thresholds[idx], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + if not isinstance(thresholds[idx - 1], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + + if thresholds[idx] <= thresholds[idx - 1]: + raise Exception( + f"{threshold_identifier} threshold ({thresholds[idx]}) " + f"is greater than or equal to ({thresholds[idx - 1]}.)" + ) + return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 15cd2447..10adb5c5 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + def __init__(self, where: WhereType, thresholds: Optional[Dict] = {}) -> None: """ Initialise an application observation instance. @@ -94,16 +94,28 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.where = where self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("app_executions") is None: + self.low_threshold = 0 + self.med_threshold = 5 + self.high_threshold = 10 + else: + if self._validate_thresholds( + thresholds=[ + thresholds.get("app_executions")["low"], + thresholds.get("app_executions")["medium"], + thresholds.get("app_executions")["high"], + ], + threshold_identifier="app_executions", + ): + self.low_threshold = thresholds.get("app_executions")["low"] + self.med_threshold = thresholds.get("app_executions")["medium"] + self.high_threshold = thresholds.get("app_executions")["high"] def _categorise_num_executions(self, num_executions: int) -> int: """ - Represent number of file accesses as a categorical variable. + Represent number of application executions as a categorical variable. - :param num_access: Number of file accesses. + :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ if num_executions > self.high_threshold: @@ -161,4 +173,4 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls(where=parent_where + ["applications", config.application_name], thresholds=config.thresholds) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 123b6ddd..441ea632 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -531,7 +531,7 @@ class PrimaiteGame: reward_function_cfg = agent_cfg["reward_function"] # CREATE OBSERVATION SPACE - obs_space = ObservationManager.from_config(observation_space_cfg) + obs_space = ObservationManager.from_config(config=observation_space_cfg, thresholds=game.options.thresholds) # CREATE ACTION SPACE action_space = ActionManager.from_config(game, action_space_cfg) diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index fed0f52d..03cf2207 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -25,7 +25,19 @@ game: - ICMP - TCP - UDP - + thresholds: + nmne: + high: 100 + medium: 25 + low: 5 + file_access: + high: 10 + medium: 5 + low: 2 + app_executions: + high: 5 + medium: 3 + low: 2 agents: - ref: client_2_green_user team: GREEN @@ -79,10 +91,16 @@ agents: options: hosts: - hostname: client_1 + applications: + - application_name: WebBrowser + folders: + - folder_name: root + files: + - file_name: "test.txt" - hostname: client_2 - hostname: client_3 num_services: 1 - num_applications: 0 + num_applications: 1 num_folders: 1 num_files: 1 num_nics: 2 @@ -219,6 +237,9 @@ simulation: options: ntp_server_ip: 192.168.1.10 - type: NTPServer + file_system: + - root: + - "test.txt" - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 32d88c92..2cb5520e 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from tests import TEST_ASSETS_ROOT -BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" +BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -24,3 +24,42 @@ def test_thresholds(): game = load_config(data_manipulation_config_path()) assert game.options.thresholds is not None + + +def test_nmne_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["nmne"] is not None + + # get NIC observation + nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] + assert nic_obs.low_threshold == 5 + assert nic_obs.med_threshold == 25 + assert nic_obs.high_threshold == 100 + + +def test_file_access_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["file_access"] is not None + + # get file observation + file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] + assert file_obs.low_threshold == 2 + assert file_obs.med_threshold == 5 + assert file_obs.high_threshold == 10 + + +def test_app_executions_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["app_executions"] is not None + + # get application observation + app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] + assert app_obs.low_threshold == 2 + assert app_obs.med_threshold == 3 + assert app_obs.high_threshold == 5 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index e2ab2990..cbd9f8c0 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -44,6 +44,38 @@ def test_file_observation(simulation): assert observation_state.get("health_status") == 3 # corrupted +def test_config_file_access_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert file_obs.high_threshold == 9 + assert file_obs.med_threshold == 6 + assert file_obs.low_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, + ) + + def test_folder_observation(simulation): """Test the folder observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index ef789ba7..cafdec45 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -110,33 +110,28 @@ def test_nic_categories(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) - assert nic_obs.high_nmne_threshold == 10 # default - assert nic_obs.med_nmne_threshold == 5 # default - assert nic_obs.low_nmne_threshold == 0 # default + assert nic_obs.high_threshold == 10 # default + assert nic_obs.med_threshold == 5 # default + assert nic_obs.low_threshold == 0 # default -@pytest.mark.skip(reason="Feature not implemented yet") def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=6, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) - assert nic_obs.high_nmne_threshold == 9 - assert nic_obs.med_nmne_threshold == 6 - assert nic_obs.low_nmne_threshold == 3 + assert nic_obs.high_threshold == 9 + assert nic_obs.med_threshold == 6 + assert nic_obs.low_threshold == 3 with pytest.raises(Exception): # should throw an error NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=9, - med_nmne_threshold=6, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -144,9 +139,7 @@ def test_config_nic_categories(simulation): # should throw an error NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=9, - high_nmne_threshold=9, + thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 998aa755..25081585 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -69,3 +69,30 @@ def test_application_observation(simulation): assert observation_state.get("health_status") == 1 assert observation_state.get("operating_status") == 1 # running assert observation_state.get("num_executions") == 1 + + +def test_application_executions_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + app_obs = ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert app_obs.high_threshold == 9 + assert app_obs.med_threshold == 6 + assert app_obs.low_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.hostname, "applications", "WebBrowser"], + thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, + ) From 8d3760b5a7e8bf53f8a7e20cabc3a5597ecd897f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 17 Sep 2024 16:19:43 +0100 Subject: [PATCH 54/72] #2871 - Fix notebook failure --- .../notebooks/Getting-Information-Out-Of-PrimAITE.ipynb | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index e4009822..6a60c1bc 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -160,6 +160,11 @@ "metadata": {}, "outputs": [], "source": [ + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", "# Run the training session to generate some resultant data.\n", "for i in range(100):\n", " env.step(0)" From 0c576746aa1165aac7aa6fc6eebda68a25945249 Mon Sep 17 00:00:00 2001 From: Nick Todd Date: Thu, 19 Sep 2024 11:07:00 +0100 Subject: [PATCH 55/72] #2896: Bump version. --- src/primaite/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 15a27998..688932aa 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -3.3.0 +3.4.0-dev From 88cbb783bc6cd11aed890671be0eb4fa02e371e1 Mon Sep 17 00:00:00 2001 From: Archer Bowen Date: Fri, 20 Sep 2024 13:54:13 +0100 Subject: [PATCH 56/72] #2840 Fixed sphinx user guide formatting issues. --- docs/source/request_system.rst | 2 ++ .../system/services/terminal.rst | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index f2d2e68d..6b71bf25 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -2,6 +2,8 @@ © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +.. _request_system: + Request System ************** diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index de6eaf0a..de0bb026 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -38,31 +38,31 @@ Implementation - Manages remote connections in a dictionary by session ID. - Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate. - Extends Service class. - - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. Command Format ^^^^^^^^^^^^^^ -``Terminals`` implement their commands through leveraging the pre-existing :doc:`../../request_system`. +Terminals implement their commands through leveraging the pre-existing :ref:`request_system`. -Due to this ``Terminals`` will only accept commands passed within the ``RequestFormat``. +Due to this Terminals will only accept commands passed within the ``RequestFormat``. :py:class:`primaite.game.interface.RequestFormat` For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows: .. code-block:: yaml + command: - "file_system" - "create" - "file" - "downloads" - "cat.png" - - "False" + - "False -**This command creates file called ``cat.png`` within the ``downloads`` folder.** - -This is then loaded from ``yaml`` into a dictionary containing the terminal command: +This is then loaded from yaml into a dictionary containing the terminal command: .. code-block:: python From e29815305dd4eaf0294b73a6b0d4a2e3ba4ccb75 Mon Sep 17 00:00:00 2001 From: Archer Bowen Date: Tue, 24 Sep 2024 11:06:38 +0100 Subject: [PATCH 57/72] #2840 Addressing PR comments. --- .../simulation_components/system/services/terminal.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index de0bb026..b11d74bb 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -26,7 +26,7 @@ Key capabilities Usage """"" - - Pre-Installs on any `Node` (component with the exception of `Switches`). + - Pre-Installs on any `Node` component (with the exception of `Switches`). - Terminal Clients connect, execute commands and disconnect from remote nodes. - Ensures that users are logged in to the component before executing any commands. - Service runs on SSH port 22 by default. @@ -68,7 +68,7 @@ This is then loaded from yaml into a dictionary containing the terminal command: {"command":["file_system", "create", "file", "downloads", "cat.png", "False"]} -Which is then parsed to the ``Terminals`` Request Manager to be executed. +Which is then passed to the ``Terminals`` Request Manager to be executed. Game Layer Usage (Agents) ======================== @@ -121,7 +121,7 @@ Agents are able to use the terminal to login into remote nodes via ``SSH`` which ``NODE_SEND_REMOTE_COMMAND`` """""""""""""""""""""""""""" -After remotely login into another host, a agent can use the ``NODE_SEND_REMOTE_COMMAND`` to execute commands across the network remotely. +After remotely logging into another host, an agent can use the ``NODE_SEND_REMOTE_COMMAND`` to execute commands across the network remotely. .. code-block:: yaml From b9df2bd6a8d4e1259213bd08fd61c489fc26a845 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Wed, 25 Sep 2024 10:50:26 +0100 Subject: [PATCH 58/72] #2445: apply PR suggestions --- CHANGELOG.md | 1 + .../observations/file_system_observations.py | 57 +++++++++---------- .../agent/observations/nic_observations.py | 38 ++++++++----- .../game/agent/observations/observations.py | 4 +- .../observations/software_observation.py | 42 +++++++++----- .../test_game_options_config.py | 18 +++--- .../test_file_system_observations.py | 6 +- .../observations/test_nic_observations.py | 12 ++-- .../test_software_observations.py | 6 +- 9 files changed, 103 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 056742e4..c748a969 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Log observation space data by episode and step. - Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted. - New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to be able to set the observation threshold for NMNE, file access and application executions ### Changed - ACL's are no longer applied to layer-2 traffic. diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index fe959c9f..b24b26a6 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -55,21 +55,35 @@ class FileObservation(AbstractObservation, identifier="FILE"): self.default_observation["num_access"] = 0 if thresholds.get("file_access") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_file_access_threshold = 0 + self.med_file_access_threshold = 5 + self.high_file_access_threshold = 10 else: - if self._validate_thresholds( + self._set_file_access_threshold( thresholds=[ thresholds.get("file_access")["low"], thresholds.get("file_access")["medium"], thresholds.get("file_access")["high"], - ], - threshold_identifier="file_access", - ): - self.low_threshold = thresholds.get("file_access")["low"] - self.med_threshold = thresholds.get("file_access")["medium"] - self.high_threshold = thresholds.get("file_access")["high"] + ] + ) + + def _set_file_access_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the file access threshold. + + :param: thresholds: The file access threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="file_access", + ): + self.low_file_access_threshold = thresholds[0] + self.med_file_access_threshold = thresholds[1] + self.high_file_access_threshold = thresholds[2] def _categorise_num_access(self, num_access: int) -> int: """ @@ -78,11 +92,11 @@ class FileObservation(AbstractObservation, identifier="FILE"): :param num_access: Number of file accesses. :return: Bin number corresponding to the number of accesses. """ - if num_access > self.high_threshold: + if num_access > self.high_file_access_threshold: return 3 - elif num_access > self.med_threshold: + elif num_access > self.med_file_access_threshold: return 2 - elif num_access > self.low_threshold: + elif num_access > self.low_file_access_threshold: return 1 return 0 @@ -190,23 +204,6 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self.file_system_requires_scan: bool = file_system_requires_scan - if thresholds.get("file_access") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 - else: - if self._validate_thresholds( - thresholds=[ - thresholds.get("file_access")["low"], - thresholds.get("file_access")["medium"], - thresholds.get("file_access")["high"], - ], - threshold_identifier="file_access", - ): - self.low_threshold = thresholds.get("file_access")["low"] - self.med_threshold = thresholds.get("file_access")["medium"] - self.high_threshold = thresholds.get("file_access")["high"] - self.files: List[FileObservation] = files while len(self.files) < num_files: self.files.append( diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 30ee240d..0dabd9f4 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import ClassVar, Dict, Optional +from typing import ClassVar, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -55,21 +55,17 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.nmne_outbound_last_step: int = 0 if thresholds.get("nmne") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_nmne_threshold = 0 + self.med_nmne_threshold = 5 + self.high_nmne_threshold = 10 else: - if self._validate_thresholds( + self._set_nmne_threshold( thresholds=[ thresholds.get("nmne")["low"], thresholds.get("nmne")["medium"], thresholds.get("nmne")["high"], - ], - threshold_identifier="nmne", - ): - self.low_threshold = thresholds.get("nmne")["low"] - self.med_threshold = thresholds.get("nmne")["medium"] - self.high_threshold = thresholds.get("nmne")["high"] + ] + ) self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -108,11 +104,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): :param nmne_count: Number of MNEs detected. :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. """ - if nmne_count > self.high_threshold: + if nmne_count > self.high_nmne_threshold: return 3 - elif nmne_count > self.med_threshold: + elif nmne_count > self.med_nmne_threshold: return 2 - elif nmne_count > self.low_threshold: + elif nmne_count > self.low_nmne_threshold: return 1 return 0 @@ -126,6 +122,20 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): bandwidth_utilisation = traffic_value / nic_max_bandwidth return int(bandwidth_utilisation * 9) + 1 + def _set_nmne_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the NMNE threshold. + + :param: thresholds: The NMNE threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=thresholds, + threshold_identifier="nmne", + ): + self.low_nmne_threshold = thresholds[0] + self.med_nmne_threshold = thresholds[1] + self.high_nmne_threshold = thresholds[2] + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 0b209f52..7a31a26b 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -97,7 +97,7 @@ class AbstractObservation(ABC): if thresholds[idx] <= thresholds[idx - 1]: raise Exception( - f"{threshold_identifier} threshold ({thresholds[idx]}) " - f"is greater than or equal to ({thresholds[idx - 1]}.)" + f"{threshold_identifier} threshold ({thresholds[idx - 1]}) " + f"is greater than or equal to ({thresholds[idx]}.)" ) return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 10ffe3fc..0318c864 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -109,21 +109,35 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} if thresholds.get("app_executions") is None: - self.low_threshold = 0 - self.med_threshold = 5 - self.high_threshold = 10 + self.low_app_execution_threshold = 0 + self.med_app_execution_threshold = 5 + self.high_app_execution_threshold = 10 else: - if self._validate_thresholds( + self._set_application_execution_thresholds( thresholds=[ thresholds.get("app_executions")["low"], thresholds.get("app_executions")["medium"], thresholds.get("app_executions")["high"], - ], - threshold_identifier="app_executions", - ): - self.low_threshold = thresholds.get("app_executions")["low"] - self.med_threshold = thresholds.get("app_executions")["medium"] - self.high_threshold = thresholds.get("app_executions")["high"] + ] + ) + + def _set_application_execution_thresholds(self, thresholds: List[int]): + """ + Method that validates and then sets the application execution threshold. + + :param: thresholds: The application execution threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="app_executions", + ): + self.low_app_execution_threshold = thresholds[0] + self.med_app_execution_threshold = thresholds[1] + self.high_app_execution_threshold = thresholds[2] def _categorise_num_executions(self, num_executions: int) -> int: """ @@ -132,11 +146,11 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ - if num_executions > self.high_threshold: + if num_executions > self.high_app_execution_threshold: return 3 - elif num_executions > self.med_threshold: + elif num_executions > self.med_app_execution_threshold: return 2 - elif num_executions > self.low_threshold: + elif num_executions > self.low_app_execution_threshold: return 1 return 0 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 2cb5520e..4098db7f 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -34,9 +34,9 @@ def test_nmne_threshold(): # get NIC observation nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] - assert nic_obs.low_threshold == 5 - assert nic_obs.med_threshold == 25 - assert nic_obs.high_threshold == 100 + assert nic_obs.low_nmne_threshold == 5 + assert nic_obs.med_nmne_threshold == 25 + assert nic_obs.high_nmne_threshold == 100 def test_file_access_threshold(): @@ -47,9 +47,9 @@ def test_file_access_threshold(): # get file observation file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] - assert file_obs.low_threshold == 2 - assert file_obs.med_threshold == 5 - assert file_obs.high_threshold == 10 + assert file_obs.low_file_access_threshold == 2 + assert file_obs.med_file_access_threshold == 5 + assert file_obs.high_file_access_threshold == 10 def test_app_executions_threshold(): @@ -60,6 +60,6 @@ def test_app_executions_threshold(): # get application observation app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] - assert app_obs.low_threshold == 2 - assert app_obs.med_threshold == 3 - assert app_obs.high_threshold == 5 + assert app_obs.low_app_execution_threshold == 2 + assert app_obs.med_app_execution_threshold == 3 + assert app_obs.high_app_execution_threshold == 5 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index cbd9f8c0..6356c297 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -53,9 +53,9 @@ def test_config_file_access_categories(simulation): thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, ) - assert file_obs.high_threshold == 9 - assert file_obs.med_threshold == 6 - assert file_obs.low_threshold == 3 + assert file_obs.high_file_access_threshold == 9 + assert file_obs.med_file_access_threshold == 6 + assert file_obs.low_file_access_threshold == 3 with pytest.raises(Exception): # should throw an error diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 9b2baf25..d01d0c8e 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -118,9 +118,9 @@ def test_nic_categories(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) - assert nic_obs.high_threshold == 10 # default - assert nic_obs.med_threshold == 5 # default - assert nic_obs.low_threshold == 0 # default + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default def test_config_nic_categories(simulation): @@ -131,9 +131,9 @@ def test_config_nic_categories(simulation): include_nmne=True, ) - assert nic_obs.high_threshold == 9 - assert nic_obs.med_threshold == 6 - assert nic_obs.low_threshold == 3 + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 with pytest.raises(Exception): # should throw an error diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 22374718..a0637969 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -84,9 +84,9 @@ def test_application_executions_categories(simulation): thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, ) - assert app_obs.high_threshold == 9 - assert app_obs.med_threshold == 6 - assert app_obs.low_threshold == 3 + assert app_obs.high_app_execution_threshold == 9 + assert app_obs.med_app_execution_threshold == 6 + assert app_obs.low_app_execution_threshold == 3 with pytest.raises(Exception): # should throw an error From 603c68acf9cc2b7ef6c247e3b8eab26040f62e8a Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 26 Sep 2024 08:51:30 +0100 Subject: [PATCH 59/72] #2445: grammar in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c748a969..bd4b992c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Log observation space data by episode and step. - Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `DONOTHING` actions are omitted. - New ``NODE_SEND_LOCAL_COMMAND`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) -- Added ability to be able to set the observation threshold for NMNE, file access and application executions +- Added ability to set the observation threshold for NMNE, file access and application executions ### Changed - ACL's are no longer applied to layer-2 traffic. From 17fe5cb043341f5ab52db433a059cbe36276cf77 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 27 Sep 2024 10:47:38 +0100 Subject: [PATCH 60/72] #2897: How to guide on dev mode --- .../How-To-Use-Primaite-Dev-Mode.ipynb | 479 ++++++++++++++++++ 1 file changed, 479 insertions(+) create mode 100644 src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb diff --git a/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb new file mode 100644 index 00000000..8f8ec24b --- /dev/null +++ b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrimAITE Developer mode\n", + "\n", + "PrimAITE has built in developer tools.\n", + "\n", + "The dev-mode is designed to help make the development of PrimAITE easier.\n", + "\n", + "`NOTE: For the purposes of the notebook, the commands are preceeded by \"!\". When running the commands, run it without the \"!\".`\n", + "\n", + "To display the available dev-mode options, run the command below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode --help" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the current PRIMAITE_CONFIG to restore after the notebook runs\n", + "\n", + "from primaite import PRIMAITE_CONFIG\n", + "\n", + "temp_config = PRIMAITE_CONFIG.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dev mode options" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### enable\n", + "\n", + "Enables the dev mode for PrimAITE.\n", + "\n", + "This will enable the developer mode for PrimAITE.\n", + "\n", + "By default, when developer mode is enabled, session logs will be generated in the PRIMAITE_ROOT/sessions folder unless configured to be generated in another location." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode enable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### disable\n", + "\n", + "Disables the dev mode for PrimAITE.\n", + "\n", + "This will disable the developer mode for PrimAITE." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode disable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### show\n", + "\n", + "Shows if PrimAITE is running in dev mode or production mode.\n", + "\n", + "The command will also show the developer mode configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode show" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### config\n", + "\n", + "Configure the PrimAITE developer mode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### path\n", + "\n", + "Set the path where generated session files will be output.\n", + "\n", + "By default, this value will be in PRIMAITE_ROOT/sessions.\n", + "\n", + "To reset the path to default, run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config path -root\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config path --default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --sys-log-level or -slevel\n", + "\n", + "Set the system log level.\n", + "\n", + "This will override the system log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --sys-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -slevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --agent-log-level or -alevel\n", + "\n", + "Set the agent log level.\n", + "\n", + "This will override the agent log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --agent-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -alevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-sys-logs or -sys\n", + "\n", + "If enabled, developer mode will output system logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting sys logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nsys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-agent-logs or -agent\n", + "\n", + "If enabled, developer mode will output agent action logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting agent action logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nagent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-pcap-logs or -pcap\n", + "\n", + "If enabled, developer mode will output PCAP logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -pcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting PCAP logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -npcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-to-terminal or -t\n", + "\n", + "If enabled, developer mode will output logs to the terminal.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-to-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -t" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable terminal outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining commands\n", + "\n", + "It is possible to combine commands to set the configuration.\n", + "\n", + "This saves having to enter multiple commands and allows for a much more efficient setting of PrimAITE developer mode configurations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of setting system log level and enabling the system logging:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel WARNING -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another example where the system log and agent action log levels are set and enabled and should be printed to terminal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel ERROR -sys -alevel ERROR -agent -t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restore PRIMAITE_CONFIG\n", + "from primaite.utils.cli.primaite_config_utils import update_primaite_application_config\n", + "\n", + "\n", + "global PRIMAITE_CONFIG\n", + "PRIMAITE_CONFIG[\"developer_mode\"] = temp_config[\"developer_mode\"]\n", + "update_primaite_application_config(config=PRIMAITE_CONFIG)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From ac921749a7f0f938505b35275ca9c7f8025c2ccb Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 30 Sep 2024 17:38:24 +0100 Subject: [PATCH 61/72] #2900 - Changes to terminal to include a last_response attribute, for use in obtaining RequestResponse from remote command executions --- .../simulator/network/hardware/base.py | 2 +- .../system/services/terminal/terminal.py | 47 +++++++++++++++---- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index f49d0a17..570a69b3 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -1711,7 +1711,7 @@ class Node(SimComponent): """ application_name = request[0] if self.software_manager.software.get(application_name): - self.sys_log.warning(f"Can't install {application_name}. It's already installed.") + self.sys_log.info(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) application_class = Application._application_registry[application_name] self.software_manager.install(application_class) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index dc7da205..77c67460 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -135,12 +135,20 @@ class Terminal(Service): _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" + _last_response: Optional[RequestResponse] = None + """Last response received from RequestManager, for returning remote RequestResponse.""" + def __init__(self, **kwargs): kwargs["name"] = "Terminal" kwargs["port"] = Port.SSH kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) + @property + def last_response(self) -> Optional[RequestResponse]: + """Public version of _last_response attribute.""" + return self._last_response + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -202,12 +210,8 @@ class Terminal(Service): command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) + remote_connection.execute(command) + return self.last_response if not None else RequestResponse(status="failure", data={}) return RequestResponse( status="failure", data={}, @@ -243,7 +247,8 @@ class Terminal(Service): def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + self._last_response = self.parent.apply_request(command) + return self._last_response def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" @@ -423,10 +428,11 @@ class Terminal(Service): """ source_ip = kwargs["frame"].ip.src_ip_address self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") + self._last_response = None # Clear last response + if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection - # TODO: uncomment this as part of 2781 username = payload.user_account.username password = payload.user_account.password connection_id = self.parent.user_session_manager.remote_login( @@ -472,6 +478,9 @@ class Terminal(Service): session_id=session_id, source_ip=source_ip, ) + self._last_response: RequestResponse = RequestResponse( + status="success", data={"reason": "Login Successful"} + ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed @@ -483,12 +492,32 @@ class Terminal(Service): payload.connection_uuid ) remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep - self.execute(command) + self._last_response: RequestResponse = self.execute(command) + + if self._last_response.status == "success": + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + else: + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED + + payload: SSHPacket = SSHPacket( + payload=self._last_response, + transport_message=transport_message, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) return True else: self.sys_log.error( f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." ) + elif ( + payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + or SSHTransportMessage.SSH_MSG_SERVICE_FAILED + ): + # Likely receiving command ack from remote. + self._last_response = payload.payload if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": From 3dafad71b32dc21eea918e145dbc89a232c59794 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 1 Oct 2024 10:45:03 +0100 Subject: [PATCH 62/72] #2900 - New test to show that last_response updates as expected. Changelog updated. --- CHANGELOG.md | 1 + .../system/services/terminal/terminal.py | 2 +- .../_system/_services/test_terminal.py | 37 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd4b992c..4a1f7919 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 or `generate_seed_value` is set to `true`. - ARP .show() method will now include the port number associated with each entry. - Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. ## [3.3.0] - 2024-09-04 ### Added diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 77c67460..ed6854f4 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -214,7 +214,7 @@ class Terminal(Service): return self.last_response if not None else RequestResponse(status="failure", data={}) return RequestResponse( status="failure", - data={}, + data={"reason": "Failed to execute command."}, ) rm.add_request( diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 3c3daa61..14cc5877 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -6,6 +6,7 @@ import pytest from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -403,3 +404,39 @@ def test_terminal_connection_timeout(basic_network): assert len(computer_b.user_session_manager.remote_sessions) == 0 assert not remote_connection.is_active + + +def test_terminal_last_response_updates(basic_network): + """Test that the _last_response within Terminal correctly updates.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("Terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert terminal_a.last_response is None + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + # Last response should be a successful logon + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"}) + + remote_connection.execute(command=["software_manager", "application", "install", "RansomwareScript"]) + + # Last response should now update following successful install + assert terminal_a.last_response == RequestResponse(status="success", data={}) + + remote_connection.execute(command=["software_manager", "application", "install", "RansomwareScript"]) + + # Last response should now update to success, but with supplied reason. + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"}) + + remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False]) + + # Check file was created. + assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf") + + # Last response should be confirmation of file creation. + assert terminal_a.last_response == RequestResponse( + status="success", + data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400}, + ) From 1e1d1524810b271cc60c5393d123d945aa4d6c1f Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 1 Oct 2024 11:02:23 +0100 Subject: [PATCH 63/72] #2900 - Updates to Terminal-Processing jupyter notebook to include a mention of last_response --- .../notebooks/Terminal-Processing.ipynb | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 19ce567e..2ab06a5c 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -167,6 +167,22 @@ "computer_b.file_system.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(terminal_a.last_response)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -488,7 +504,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, From fcfea3474fe1f5701faa07ab245f60978d8246b4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 1 Oct 2024 11:41:42 +0100 Subject: [PATCH 64/72] #2900 - typo in test_ftp_client and expanded test_terminal_last_response_updates to include a failure scenario --- .../_system/_services/test_ftp_client.py | 2 +- .../_system/_services/test_terminal.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 3ce4d8ee..99bb42ed 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -71,7 +71,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR -def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client): +def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): """Method send_file should return false if no file to send.""" assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 14cc5877..55f89c04 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -440,3 +440,23 @@ def test_terminal_last_response_updates(basic_network): status="success", data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400}, ) + + remote_connection.execute( + command=[ + "service", + "FTPClient", + "send", + { + "dest_ip_address": "192.168.0.2", + "src_folder": "folder123", + "src_file_name": "cat.pdf", + "dest_folder": "root", + "dest_file_name": "cat.pdf", + }, + ] + ) + + assert terminal_a.last_response == RequestResponse( + status="failure", + data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"}, + ) From 96549e68aa9cddff6e1eecc9da35a288e503d021 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 10 Feb 2025 14:39:28 +0000 Subject: [PATCH 65/72] Merge remote-tracking branch 'origin/dev' into 4.0.0-dev --- CHANGELOG.md | 10 + docs/source/configuration/agents.rst | 3 +- .../nodes/common/common_node_attributes.rst | 33 ++ .../simulation/nodes/network_examples.rst | 12 +- .../configuration/simulation/nodes/router.rst | 6 +- .../how_to_guides/extensible_agents.rst | 6 +- .../source/how_to_guides/extensible_nodes.rst | 8 +- docs/source/request_system.rst | 2 + .../simulation_components/network/network.rst | 8 +- .../network/nodes/wireless_router.rst | 4 +- .../system/applications/c2_suite.rst | 2 +- .../applications/data_manipulation_bot.rst | 4 +- .../system/applications/database_client.rst | 2 +- .../system/applications/ransomware_script.rst | 2 +- .../system/applications/web_browser.rst | 2 +- .../system/services/database_service.rst | 2 +- .../system/services/dns_client.rst | 2 +- .../system/services/dns_server.rst | 2 +- .../system/services/ftp_client.rst | 2 +- .../system/services/ftp_server.rst | 2 +- .../system/services/ntp_client.rst | 2 +- .../system/services/ntp_server.rst | 2 +- .../system/services/terminal.rst | 131 ++++- .../system/services/web_server.rst | 2 +- .../simulation_components/system/software.rst | 2 +- src/primaite/game/agent/actions/node.py | 71 ++- src/primaite/game/agent/actions/session.py | 13 +- src/primaite/game/agent/interface.py | 54 +- .../observations/file_system_observations.py | 53 +- .../agent/observations/host_observations.py | 46 +- .../agent/observations/nic_observations.py | 49 +- .../agent/observations/node_observations.py | 18 +- .../agent/observations/observation_manager.py | 9 +- .../game/agent/observations/observations.py | 36 +- .../observations/software_observation.py | 78 ++- src/primaite/game/game.py | 25 +- ...ommand-and-Control-E2E-Demonstration.ipynb | 174 +++---- ...a-Manipulation-Customising-Red-Agent.ipynb | 56 +- .../Data-Manipulation-E2E-Demonstration.ipynb | 11 +- .../Getting-Information-Out-Of-PrimAITE.ipynb | 45 +- .../How-To-Use-Primaite-Dev-Mode.ipynb | 479 ++++++++++++++++++ .../notebooks/Requests-and-Responses.ipynb | 2 +- .../notebooks/Terminal-Processing.ipynb | 286 ++++++++++- src/primaite/session/environment.py | 30 +- .../simulator/network/hardware/base.py | 22 +- .../network/hardware/nodes/host/host_node.py | 19 +- .../hardware/nodes/network/firewall.py | 2 +- .../network/hardware/nodes/network/router.py | 21 +- .../network/transmission/data_link_layer.py | 2 +- .../simulator/system/services/arp/arp.py | 3 +- .../system/services/terminal/terminal.py | 78 ++- .../system/services/web_server/web_server.py | 55 +- .../configs/basic_switched_network.yaml | 26 +- .../configs/nodes_with_initial_files.yaml | 226 +++++++++ .../test_game_options_config.py | 41 +- .../test_node_file_system_config.py | 64 +++ .../extensions/nodes/giga_switch.py | 2 +- .../actions/test_terminal_actions.py | 54 +- .../actions/test_user_account_actions.py | 176 +++++++ .../test_file_system_observations.py | 32 ++ .../observations/test_nic_observations.py | 27 +- .../observations/test_node_observations.py | 2 + .../observations/test_obs_data_capture.py | 28 + .../test_software_observations.py | 38 +- .../game_layer/test_RNG_seed.py | 22 + .../game_layer/test_actions.py | 10 +- .../network/test_capture_nmne.py | 22 + tests/integration_tests/system/test_arp.py | 21 +- .../_game/_agent/test_observations.py | 227 ++++++++- .../_system/_services/test_ftp_client.py | 2 +- .../_system/_services/test_terminal.py | 57 +++ 71 files changed, 2700 insertions(+), 367 deletions(-) create mode 100644 src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb create mode 100644 tests/assets/configs/nodes_with_initial_files.yaml create mode 100644 tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py create mode 100644 tests/integration_tests/game_layer/actions/test_user_account_actions.py create mode 100644 tests/integration_tests/game_layer/observations/test_obs_data_capture.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f87f54e..871b9923 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [4.0.0] = TBC ### Added +- Log observation space data by episode and step. +- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted. +- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only) +- Added ability to set the observation threshold for NMNE, file access and application executions ### Changed - Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty. @@ -24,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated tests that don't use YAMLs to still use the new action and agent schemas - Nodes now use a config schema and are extensible, allowing for plugin support. - Node tests have been updated to use the new node config schemas when not using YAML files. +- ACLs are no longer applied to layer-2 traffic. +- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file + or `generate_seed_value` is set to `true`. +- ARP .show() method will now include the port number associated with each entry. +- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning). +- Updated the `Terminal` class to provide response information when sending remote command execution. ### Fixed - DNS client no longer fails to check its cache if a DNS server address is missing. diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index dce8da3a..ee84aede 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -21,7 +21,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo team: GREEN type: probabilistic-agent observation_space: - type: UC2GreenObservation + type: UC2GreenObservation # TODO: what action_space: reward_function: reward_components: @@ -160,3 +160,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef ----------------- Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation. +A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``. diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst index 542b817b..e6d5da67 100644 --- a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -54,6 +54,39 @@ Optional. Default value is ``3``. The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. +``file_system`` +--------------- + +Optional. + +The file system of the node. This configuration allows nodes to be initialised with files and/or folders. + +The file system takes a list of folders and files. + +Example: + +.. code-block:: yaml + + simulation: + network: + nodes: + - hostname: client_1 + type: computer + ip_address: 192.168.10.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + file_system: + - empty_folder # example of an empty folder + - downloads: + - "test_1.txt" # files in the downloads folder + - "test_2.txt" + - root: + - passwords: # example of file with size and type + size: 69 # size in bytes + type: TXT # See FileType for list of available file types + +List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType` + ``users`` --------- diff --git a/docs/source/configuration/simulation/nodes/network_examples.rst b/docs/source/configuration/simulation/nodes/network_examples.rst index 4616139e..84ee4c60 100644 --- a/docs/source/configuration/simulation/nodes/network_examples.rst +++ b/docs/source/configuration/simulation/nodes/network_examples.rst @@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv") some_tech_storage_srv.file_system.create_file(file_name="test.png") - pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"] - pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"] + pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"] + pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"] assert not pc_1_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. web_server: Server = network.get_node_by_hostname("some_tech_web_srv") - web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] + web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"] assert not web_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc") - snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"] + snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"] assert snr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc") - jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"] + jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"] assert not jnr_dev_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, @@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules. some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1") - hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"] + hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"] assert not hr_ftp_client.request_file( dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address, diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst index 4b41784c..ee278a98 100644 --- a/docs/source/configuration/simulation/nodes/router.rst +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -74,7 +74,7 @@ The subnet mask setting for the port. ``acl`` ------- -Sets up the ACL rules for the router. +Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP. e.g. @@ -85,10 +85,6 @@ e.g. ... acl: 1: - action: PERMIT - src_port: ARP - dst_port: ARP - 2: action: PERMIT protocol: ICMP diff --git a/docs/source/how_to_guides/extensible_agents.rst b/docs/source/how_to_guides/extensible_agents.rst index 1d765417..3236c21a 100644 --- a/docs/source/how_to_guides/extensible_agents.rst +++ b/docs/source/how_to_guides/extensible_agents.rst @@ -46,17 +46,13 @@ The core features that should be implemented in any new agent are detailed below - ref: example_green_agent team: GREEN - type: ExampleAgent + type: example-agent action_space: action_map: 0: action: do-nothing options: {} - reward_function: - reward_components: - - type: dummy - agent_settings: start_step: 25 frequency: 20 diff --git a/docs/source/how_to_guides/extensible_nodes.rst b/docs/source/how_to_guides/extensible_nodes.rst index 043d0f06..18d64ca8 100644 --- a/docs/source/how_to_guides/extensible_nodes.rst +++ b/docs/source/how_to_guides/extensible_nodes.rst @@ -26,9 +26,9 @@ class Router(NetworkNode, identifier="router"): """ Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces.""" SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} @@ -52,4 +52,4 @@ class Router(NetworkNode, identifier="router"): Changes to YAML file. ===================== -While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. +While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed. diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index f0437705..93fc2a9f 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -2,6 +2,8 @@ © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +.. _request_system: + Request System ************** diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 152b74b8..a6fe4070 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -97,19 +97,19 @@ we'll use the following Network that has a client, server, two switches, and a r network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1]) network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1]) -8. Add ACL rules on the Router to allow ARP and ICMP traffic. +8. Add an ACL rule on the Router to allow ICMP traffic. .. code-block:: python router_1.acl.add_rule( action=ACLAction.PERMIT, - src_port=Port["ARP"], - dst_port=Port["ARP"], + src_port=PORT_LOOKUP["ARP"], + dst_port=PORT_LOOKUP["ARP"], position=22 ) router_1.acl.add_rule( action=ACLAction.PERMIT, - protocol=IPProtocol["ICMP"], + protocol=PROTOCOL_LOOKUP["ICMP"], position=23 ) diff --git a/docs/source/simulation_components/network/nodes/wireless_router.rst b/docs/source/simulation_components/network/nodes/wireless_router.rst index d7207846..4078ffda 100644 --- a/docs/source/simulation_components/network/nodes/wireless_router.rst +++ b/docs/source/simulation_components/network/nodes/wireless_router.rst @@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality. network.connect(pc_a.network_interface[1], router_1.router_interface) # Configure Router 1 ACLs - router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22) - router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23) + router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22) + router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23) # Configure PC B pc_b = Computer( diff --git a/docs/source/simulation_components/system/applications/c2_suite.rst b/docs/source/simulation_components/system/applications/c2_suite.rst index 34175fc3..c780485a 100644 --- a/docs/source/simulation_components/system/applications/c2_suite.rst +++ b/docs/source/simulation_components/system/applications/c2_suite.rst @@ -183,7 +183,7 @@ Python # Example command: Installing and configuring Ransomware: ransomware_installation_command = { "commands": [ - ["software_manager","application","install","RansomwareScript"], + ["software_manager","application","install","ransomware-script"], ], "username": "admin", "password": "admin", diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 8e008504..3ddb8bca 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -77,7 +77,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) - data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") + data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() @@ -98,7 +98,7 @@ If not using the data manipulation bot manually, it needs to be used with a data type: red-database-corrupting-agent observation_space: - type: UC2RedObservation + type: uc2-red-observation #TODO what options: nodes: - node_name: client_1 diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst index 7087dedf..472b504c 100644 --- a/docs/source/simulation_components/system/applications/database_client.rst +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -59,7 +59,7 @@ Python # install DatabaseClient client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient") + database_client: DatabaseClient = client.software_manager.software.get("database-sclient") # Configure the DatabaseClient database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService diff --git a/docs/source/simulation_components/system/applications/ransomware_script.rst b/docs/source/simulation_components/system/applications/ransomware_script.rst index 192618fc..a8975f32 100644 --- a/docs/source/simulation_components/system/applications/ransomware_script.rst +++ b/docs/source/simulation_components/system/applications/ransomware_script.rst @@ -62,7 +62,7 @@ Python network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(RansomwareScript) - RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript") + RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script") RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14")) RansomwareScript.execute() diff --git a/docs/source/simulation_components/system/applications/web_browser.rst b/docs/source/simulation_components/system/applications/web_browser.rst index c04e60af..659caa09 100644 --- a/docs/source/simulation_components/system/applications/web_browser.rst +++ b/docs/source/simulation_components/system/applications/web_browser.rst @@ -61,7 +61,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D # Install WebBrowser on computer computer.software_manager.install(WebBrowser) - web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser: WebBrowser = computer.software_manager.software.get("web-browser") web_browser.run() # configure the WebBrowser diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst index 961f2e45..c819a0f7 100644 --- a/docs/source/simulation_components/system/services/database_service.rst +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -66,7 +66,7 @@ Python # Install DatabaseService on server server.software_manager.install(DatabaseService) - db_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_service: DatabaseService = server.software_manager.software.get("database-service") db_service.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_client.rst b/docs/source/simulation_components/system/services/dns_client.rst index 17a1ed25..40762bfc 100644 --- a/docs/source/simulation_components/system/services/dns_client.rst +++ b/docs/source/simulation_components/system/services/dns_client.rst @@ -56,7 +56,7 @@ Python # Install DNSClient on server server.software_manager.install(DNSClient) - dns_client: DNSClient = server.software_manager.software.get("DNSClient") + dns_client: DNSClient = server.software_manager.software.get("dns-client") dns_client.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/dns_server.rst b/docs/source/simulation_components/system/services/dns_server.rst index 633221d5..ca0e3691 100644 --- a/docs/source/simulation_components/system/services/dns_server.rst +++ b/docs/source/simulation_components/system/services/dns_server.rst @@ -53,7 +53,7 @@ Python # Install DNSServer on server server.software_manager.install(DNSServer) - dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server: DNSServer = server.software_manager.software.get("dns-server") dns_server.start() # configure DatabaseService diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst index d4375069..530b5aff 100644 --- a/docs/source/simulation_components/system/services/ftp_client.rst +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -60,7 +60,7 @@ Python # Install FTPClient on server server.software_manager.install(FTPClient) - ftp_client: FTPClient = server.software_manager.software.get("FTPClient") + ftp_client: FTPClient = server.software_manager.software.get("ftp-client") ftp_client.start() diff --git a/docs/source/simulation_components/system/services/ftp_server.rst b/docs/source/simulation_components/system/services/ftp_server.rst index a5ad32fe..20dd6707 100644 --- a/docs/source/simulation_components/system/services/ftp_server.rst +++ b/docs/source/simulation_components/system/services/ftp_server.rst @@ -55,7 +55,7 @@ Python # Install FTPServer on server server.software_manager.install(FTPServer) - ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server: FTPServer = server.software_manager.software.get("ftp-server") ftp_server.start() ftp_server.server_password = "test" diff --git a/docs/source/simulation_components/system/services/ntp_client.rst b/docs/source/simulation_components/system/services/ntp_client.rst index 8c011cad..5406d9fc 100644 --- a/docs/source/simulation_components/system/services/ntp_client.rst +++ b/docs/source/simulation_components/system/services/ntp_client.rst @@ -53,7 +53,7 @@ Python # Install NTPClient on server server.software_manager.install(NTPClient) - ntp_client: NTPClient = server.software_manager.software.get("NTPClient") + ntp_client: NTPClient = server.software_manager.software.get("ntp-client") ntp_client.start() ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10")) diff --git a/docs/source/simulation_components/system/services/ntp_server.rst b/docs/source/simulation_components/system/services/ntp_server.rst index c1d16d61..2c01dcaf 100644 --- a/docs/source/simulation_components/system/services/ntp_server.rst +++ b/docs/source/simulation_components/system/services/ntp_server.rst @@ -55,7 +55,7 @@ Python # Install NTPServer on server server.software_manager.install(NTPServer) - ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server: NTPServer = server.software_manager.software.get("ntp-server") ntp_server.start() diff --git a/docs/source/simulation_components/system/services/terminal.rst b/docs/source/simulation_components/system/services/terminal.rst index bc5cee48..5c9bad79 100644 --- a/docs/source/simulation_components/system/services/terminal.rst +++ b/docs/source/simulation_components/system/services/terminal.rst @@ -23,6 +23,14 @@ Key capabilities - Simulates common Terminal processes/commands. - Leverages the Service base class for install/uninstall, status tracking etc. +Usage +""""" + + - Pre-Installs on any `Node` component (with the exception of `Switches`). + - Terminal Clients connect, execute commands and disconnect from remote nodes. + - Ensures that users are logged in to the component before executing any commands. + - Service runs on SSH port 22 by default. + - Enables Agents to send commands both remotely and locally. Implementation """""""""""""" @@ -30,19 +38,112 @@ Implementation - Manages remote connections in a dictionary by session ID. - Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate. - Extends Service class. - - A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook. + +Command Format +^^^^^^^^^^^^^^ + +Terminals implement their commands through leveraging the pre-existing :ref:`request_system`. + +Due to this Terminals will only accept commands passed within the ``RequestFormat``. + +:py:class:`primaite.game.interface.RequestFormat` + +For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows: + +.. code-block:: yaml + + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False + +This is then loaded from yaml into a dictionary containing the terminal command: + +.. code-block:: python + + {"command":["file_system", "create", "file", "downloads", "cat.png", "False"]} + +Which is then passed to the ``Terminals`` Request Manager to be executed. + +Game Layer Usage (Agents) +======================== + +The below code examples demonstrate how to use terminal related actions in yaml files. + +yaml +"""" + +``node-send-local-command`` +""""""""""""""""""""""""""" + +Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``). + +.. code-block:: yaml + + ... + ... + action: node-send-local-command + options: + node_id: 0 + username: admin + password: admin + command: # Example command - Creates a file called 'cat.png' in the downloads folder. + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" -Usage -""""" +``node-session-remote-login`` +""""""""""""""""" - - Pre-Installs on all ``Nodes`` (with the exception of ``Switches``). - - Terminal Clients connect, execute commands and disconnect from remote nodes. - - Ensures that users are logged in to the component before executing any commands. - - Service runs on SSH port 22 by default. +Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts. + +.. code-block:: yaml + + ... + ... + action: node-session-remote-login + options: + node_id: 0 + username: admin + password: admin + remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh) + + +``node-send-remote-command`` +"""""""""""""""""""""""""""" + +After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely. + +.. code-block:: yaml + + ... + ... + action: node-send-remote-command + options: + node_id: 0 + remote_ip: 192.168.0.10 + command: + - "file_system" + - "create" + - "file" + - "downloads" + - "cat.png" + - "False" + + + +Simulation Layer Usage +====================== -Usage -===== The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node. @@ -65,7 +166,7 @@ Python operating_state=NodeOperatingState.ON, ) - terminal: Terminal = client.software_manager.software.get("Terminal") + terminal: Terminal = client.software_manager.software.get("terminal") Creating Remote Terminal Connection """"""""""""""""""""""""""""""""""" @@ -86,7 +187,7 @@ Creating Remote Terminal Connection node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -112,12 +213,12 @@ Executing a basic application install command node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") - term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) + term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"]) @@ -140,7 +241,7 @@ Creating a folder on a remote node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") @@ -167,7 +268,7 @@ Disconnect from Remote Node node_b.power_on() network.connect(node_a.network_interface[1], node_b.network_interface[1]) - terminal_a: Terminal = node_a.software_manager.software.get("Terminal") + terminal_a: Terminal = node_a.software_manager.software.get("terminal") term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11") diff --git a/docs/source/simulation_components/system/services/web_server.rst b/docs/source/simulation_components/system/services/web_server.rst index bce42791..9d7f4d2f 100644 --- a/docs/source/simulation_components/system/services/web_server.rst +++ b/docs/source/simulation_components/system/services/web_server.rst @@ -56,7 +56,7 @@ Python # Install WebServer on server server.software_manager.install(WebServer) - web_server: WebServer = server.software_manager.software.get("WebServer") + web_server: WebServer = server.software_manager.software.get("web-server") web_server.start() Via Configuration diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index d28815bb..c2f3066b 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down` node.software_manager.install(WebServer) - web_server: WebServer = node.software_manager.software.get("WebServer") + web_server: WebServer = node.software_manager.software.get("web-server") assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install node.power_off() diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index 19639c21..b1b6ec12 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -1,6 +1,6 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from abc import ABC, abstractmethod -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, List, Literal, Optional, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"): """Action which performs an nmap network service recon (ping scan followed by port scan).""" - config: "NodeNetworkServiceReconAction.ConfigSchema" - class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): """Configuration schema for NodeNetworkServiceReconAction.""" @@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node- "show": config.show, }, ] + + +class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-add-user"] = "node-account-add-user" + node_name: str + username: str + password: str + is_admin: bool + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "add_user", + config.username, + config.password, + config.is_admin, + ] + + +class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-account-disable-user"] = "node-account-disable-user" + node_name: str + username: str + + @classmethod + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "user-manager", + "disable_user", + config.username, + ] + + +class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"): + class ConfigSchema(AbstractAction.ConfigSchema): + type: Literal["node-send-local-command"] = "node-send-local-command" + node_name: str + username: str + password: str + command: RequestFormat + + @staticmethod + def form_request(config: ConfigSchema) -> RequestFormat: + return [ + "network", + "node", + config.node_name, + "service", + "terminal", + "send_local_command", + config.username, + config.password, + {"command": config.command}, + ] diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index 58a8a555..63a45c5e 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC): class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"): """Action which performs a remote session login.""" - config: "NodeSessionsRemoteLoginAction.ConfigSchema" - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLoginAction.""" @@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no config.node_name, "service", "terminal", - "node-session-remote-login", + "node_session_remote_login", config.username, config.password, config.remote_ip, @@ -63,8 +61,6 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node-session-remote-logoff"): """Action which performs a remote session logout.""" - config: "NodeSessionsRemoteLogoutAction.ConfigSchema" - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLogoutAction.""" @@ -78,14 +74,13 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="n return ["network", "node", config.node_name, "service", "terminal", config.verb, config.remote_ip] -class NodeAccountChangePasswordAction(NodeSessionAbstractAction, discriminator="node-account-change-password"): +class NodeAccountChangePasswordAction(AbstractAction, discriminator="node-account-change-password"): """Action which changes the password for a user.""" - config: "NodeAccountChangePasswordAction.ConfigSchema" - - class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for NodeAccountsChangePasswordAction.""" + node_name: str username: str current_password: str new_password: str diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index a55cd3ff..d06bd1d0 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING from gymnasium.core import ActType, ObsType +from prettytable import PrettyTable from pydantic import BaseModel, ConfigDict, Field from primaite.game.agent.actions import ActionManager @@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel): reward_info: Dict[str, Any] = {} + observation: Optional[ObsType] = None + """The observation space data for this step.""" + class AbstractAgent(BaseModel, ABC): """Base class for scripted and RL agents.""" @@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC): default_factory=lambda: ObservationManager.ConfigSchema() ) reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema()) + thresholds: Optional[Dict] = {} + # TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085) + """A dict containing the observation thresholds.""" config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema()) @@ -90,10 +97,42 @@ class AbstractAgent(BaseModel, ABC): def model_post_init(self, __context: Any) -> None: """Overwrite the default empty action, observation, and rewards with ones defined through the config.""" self.action_manager = ActionManager(config=self.config.action_space) + self.config.observation_space.options.thresholds = self.config.thresholds self.observation_manager = ObservationManager(config=self.config.observation_space) self.reward_function = RewardFunction(config=self.config.reward_function) return super().model_post_init(__context) + def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: + """Update the given table with information from given AgentHistoryItem.""" + node, application = "unknown", "unknown" + if (node_id := item.parameters.get("node_id")) is not None: + node = self.action_manager.node_names[node_id] + if (application_id := item.parameters.get("application_id")) is not None: + application = self.action_manager.application_names[node_id][application_id] + if (application_name := item.parameters.get("application_name")) is not None: + application = application_name + table.add_row([item.timestep, item.action, node, application, item.response.status]) + return table + + def show_history(self, ignored_actions: Optional[list] = None): + """ + Print an agent action provided it's not the DONOTHING action. + + :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. + If not provided, defaults to ignore DONOTHING actions. + """ + if not ignored_actions: + ignored_actions = ["DONOTHING"] + table = PrettyTable() + table.field_names = ["Step", "Action", "Node", "Application", "Response"] + print(f"Actions for '{self.agent_name}':") + for item in self.history: + if item.action in ignored_actions: + pass + else: + table = self.add_agent_action(item=item, table=table) + print(table) + def update_observation(self, state: Dict) -> ObsType: """ Convert a state from the simulator into an observation for the agent using the observation space. @@ -140,12 +179,23 @@ class AbstractAgent(BaseModel, ABC): return request def process_action_response( - self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + self, + timestep: int, + action: str, + parameters: Dict[str, Any], + request: RequestFormat, + response: RequestResponse, + observation: ObsType, ) -> None: """Process the response from the most recent action.""" self.history.append( AgentHistoryItem( - timestep=timestep, action=action, parameters=parameters, request=request, response=response + timestep=timestep, + action=action, + parameters=parameters, + request=request, + response=response, + observation=observation, ) ) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index ed9dcd8f..a9e3a9aa 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"): file_system_requires_scan: Optional[bool] = None """If True, the file must be scanned to update the health state. Tf False, the true state is always shown.""" - def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None: + def __init__( + self, + where: WhereType, + include_num_access: bool, + file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, + ) -> None: """ Initialise a file observation instance. @@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"): if self.include_num_access: self.default_observation["num_access"] = 0 - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("file_access") is None: + self.low_file_access_threshold = 0 + self.med_file_access_threshold = 5 + self.high_file_access_threshold = 10 + else: + self._set_file_access_threshold( + thresholds=[ + thresholds.get("file_access")["low"], + thresholds.get("file_access")["medium"], + thresholds.get("file_access")["high"], + ] + ) + + def _set_file_access_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the file access threshold. + + :param: thresholds: The file access threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="file_access", + ): + self.low_file_access_threshold = thresholds[0] + self.med_file_access_threshold = thresholds[1] + self.high_file_access_threshold = thresholds[2] def _categorise_num_access(self, num_access: int) -> int: """ @@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"): :param num_access: Number of file accesses. :return: Bin number corresponding to the number of accesses. """ - if num_access > self.high_threshold: + if num_access > self.high_file_access_threshold: return 3 - elif num_access > self.med_threshold: + elif num_access > self.med_file_access_threshold: return 2 - elif num_access > self.low_threshold: + elif num_access > self.low_file_access_threshold: return 1 return 0 @@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"): where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) @@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files: int, include_num_access: bool, file_system_requires_scan: bool, + thresholds: Optional[Dict] = {}, ) -> None: """ Initialise a folder observation instance. @@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): where=None, include_num_access=include_num_access, file_system_requires_scan=self.file_system_requires_scan, + thresholds=thresholds, ) ) while len(self.files) > num_files: @@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"): for file_config in config.files: file_config.include_num_access = config.include_num_access file_config.file_system_requires_scan = config.file_system_requires_scan + file_config.thresholds = config.thresholds files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls( @@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"): num_files=config.num_files, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 17bcb983..9b979063 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"): """ If True, files and folders must be scanned to update the health state. If False, true state is always shown. """ - include_users: Optional[bool] = None + services_requires_scan: Optional[bool] = None + """ + If True, services must be scanned to update the health state. If False, true state is always shown. + """ + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + include_users: Optional[bool] = True """If True, report user session information.""" def __init__( @@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic: Optional[Dict], include_num_access: bool, file_system_requires_scan: bool, + services_requires_scan: bool, + applications_requires_scan: bool, include_users: bool, ) -> None: """ @@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"): :param file_system_requires_scan: If True, the files and folders must be scanned to update the health state. If False, the true state is always shown. :type file_system_requires_scan: bool + :param services_requires_scan: If True, services must be scanned to update the health state. + If False, the true state is always shown. + :type services_requires_scan: bool + :param applications_requires_scan: If True, applications must be scanned to update the health state. + If False, the true state is always shown. + :type applications_requires_scan: bool :param include_users: If True, report user session information. :type include_users: bool """ @@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"): # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) + self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan)) while len(self.services) > num_services: truncated_service = self.services.pop() msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" @@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"): self.applications: List[ApplicationObservation] = applications while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) + self.applications.append( + ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan) + ) while len(self.applications) > num_applications: truncated_application = self.applications.pop() msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" @@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"): self.nics: List[NICObservation] = network_interfaces while len(self.nics) < num_nics: - self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic)) + self.nics.append( + NICObservation( + where=None, + include_nmne=include_nmne, + monitored_traffic=monitored_traffic, + ) + ) while len(self.nics) > num_nics: truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" @@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"): folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files folder_config.file_system_requires_scan = config.file_system_requires_scan + folder_config.thresholds = config.thresholds for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne + nic_config.thresholds = config.thresholds + for service_config in config.services: + service_config.services_requires_scan = config.services_requires_scan + for application_config in config.applications: + application_config.applications_requires_scan = config.applications_requires_scan + application_config.thresholds = config.thresholds services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] @@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"): count = 1 while len(nics) < config.num_nics: nic_config = NICObservation.ConfigSchema( - nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic + nic_num=count, + include_nmne=config.include_nmne, + monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) nics.append(NICObservation.from_config(config=nic_config, parent_where=where)) count += 1 @@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"): monitored_traffic=config.monitored_traffic, include_num_access=config.include_num_access, file_system_requires_scan=config.file_system_requires_scan, + services_requires_scan=config.services_requires_scan, + applications_requires_scan=config.applications_requires_scan, include_users=config.include_users, ) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 1aa6470d..8faeb906 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,13 +1,14 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict, List, Optional +from typing import ClassVar, Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import NMNEConfig from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.port import Port @@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port class NICObservation(AbstractObservation, discriminator="network-interface"): """Status information about a network interface within the simulation environment.""" + capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne + "A Boolean specifying whether malicious network events should be captured." + class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NICObservation.""" @@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None """A dict containing which traffic types are to be included in the observation.""" - def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + monitored_traffic: Optional[Dict] = None, + thresholds: Dict = {}, + ) -> None: """ Initialise a network interface observation instance. @@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): self.nmne_inbound_last_step: int = 0 self.nmne_outbound_last_step: int = 0 - # TODO: allow these to be configured in yaml - self.high_nmne_threshold = 10 - self.med_nmne_threshold = 5 - self.low_nmne_threshold = 0 + if thresholds.get("nmne") is None: + self.low_nmne_threshold = 0 + self.med_nmne_threshold = 5 + self.high_nmne_threshold = 10 + else: + self._set_nmne_threshold( + thresholds=[ + thresholds.get("nmne")["low"], + thresholds.get("nmne")["medium"], + thresholds.get("nmne")["high"], + ] + ) self.monitored_traffic = monitored_traffic if self.monitored_traffic: @@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): bandwidth_utilisation = traffic_value / nic_max_bandwidth return int(bandwidth_utilisation * 9) + 1 + def _set_nmne_threshold(self, thresholds: List[int]): + """ + Method that validates and then sets the NMNE threshold. + + :param: thresholds: The NMNE threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=thresholds, + threshold_identifier="nmne", + ): + self.low_nmne_threshold = thresholds[0] + self.med_nmne_threshold = thresholds[1] + self.high_nmne_threshold = thresholds[2] + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): """ nic_state = access_from_nested_dict(state, self.where) - if nic_state is NOT_PRESENT_IN_STATE: + if nic_state is NOT_PRESENT_IN_STATE or self.where is None: return self.default_observation obs = {"nic_status": 1 if nic_state["enabled"] else 2} @@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): for port in self.monitored_traffic[protocol]: obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0} - if self.include_nmne: + if self.capture_nmne and self.include_nmne: obs.update({"NMNE": {}}) direction_dict = nic_state["nmne"].get("direction", {}) inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) @@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"): where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic, + thresholds=config.thresholds, ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3a3283a2..260fac68 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" file_system_requires_scan: bool = True - """If True, the folder must be scanned to update the health state. Tf False, the true state is always shown.""" + """If True, the folder must be scanned to update the health state. If False, the true state is always shown.""" + services_requires_scan: bool = True + """If True, the services must be scanned to update the health state. + If False, the true state is always shown.""" + applications_requires_scan: bool = True + """If True, the applications must be scanned to update the health state. + If False, the true state is always shown.""" include_users: Optional[bool] = True """If True, report user session information.""" num_ports: Optional[int] = None @@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): host_config.include_num_access = config.include_num_access if host_config.file_system_requires_scan is None: host_config.file_system_requires_scan = config.file_system_requires_scan + if host_config.services_requires_scan is None: + host_config.services_requires_scan = config.services_requires_scan + if host_config.applications_requires_scan is None: + host_config.applications_requires_scan = config.applications_requires_scan if host_config.include_users is None: host_config.include_users = config.include_users + if not host_config.thresholds: + host_config.thresholds = config.thresholds for router_config in config.routers: if router_config.num_ports is None: @@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): router_config.num_rules = config.num_rules if router_config.include_users is None: router_config.include_users = config.include_users + if not router_config.thresholds: + router_config.thresholds = config.thresholds for firewall_config in config.firewalls: if firewall_config.ip_list is None: @@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"): firewall_config.num_rules = config.num_rules if firewall_config.include_users is None: firewall_config.include_users = config.include_users + if not firewall_config.thresholds: + firewall_config.thresholds = config.thresholds hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 032435b8..e8cb18aa 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config( + config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds) + ) instances[component.label] = obs_instance return cls(components=instances) @@ -242,8 +244,5 @@ class ObservationManager(BaseModel): """ if config is None: return cls(NullObservation()) - obs_type = config["type"] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"])) - obs_manager = cls(observation) + obs_manager = cls(config=config) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index da81d2ad..8558b75c 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union from gymnasium import spaces from gymnasium.core import ObsType @@ -19,6 +19,9 @@ class AbstractObservation(ABC): class ConfigSchema(ABC, BaseModel): """Config schema for observations.""" + thresholds: Optional[Dict] = {} + """A dict containing the observation thresholds.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -69,3 +72,34 @@ class AbstractObservation(ABC): def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() + + def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool: + """ + Method that checks if the thresholds are non overlapping and in the correct (ascending) order. + + Pass in the thresholds from low to high e.g. + thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold] + + Throws an error if the threshold is not valid + + :param: thresholds: List of thresholds in ascending order. + :type: List[int] + :param: threshold_identifier: The name of the threshold option. + :type: Optional[str] + + :returns: bool + """ + if thresholds is None or len(thresholds) < 2: + raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}") + for idx in range(1, len(thresholds)): + if not isinstance(thresholds[idx], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + if not isinstance(thresholds[idx - 1], int): + raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.") + + if thresholds[idx] <= thresholds[idx - 1]: + raise Exception( + f"{threshold_identifier} threshold ({thresholds[idx - 1]}) " + f"is greater than or equal to ({thresholds[idx]}.)" + ) + return True diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 07ec1abf..dac6b362 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,7 +1,7 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK from __future__ import annotations -from typing import Dict +from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType @@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"): service_name: str """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + services_requires_scan: Optional[bool] = None + """If True, services must be scanned to update the health state. If False, true state is always shown.""" + + def __init__(self, where: WhereType, services_requires_scan: bool) -> None: """ Initialise a service observation instance. @@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :type where: WhereType """ self.where = where + self.services_requires_scan = services_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0} def observe(self, state: Dict) -> ObsType: @@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): return self.default_observation return { "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], + "health_status": service_state["health_state_visible"] + if self.services_requires_scan + else service_state["health_state_actual"], } @property @@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"): :return: Constructed service observation instance. :rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config.service_name]) + return cls( + where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan + ) class ApplicationObservation(AbstractObservation, discriminator="application"): @@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): application_name: str """Name of the application, used for querying simulation state dictionary""" - def __init__(self, where: WhereType) -> None: + applications_requires_scan: Optional[bool] = None + """ + If True, applications must be scanned to update the health state. If False, true state is always shown. + """ + + def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None: """ Initialise an application observation instance. @@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :type where: WhereType """ self.where = where + self.applications_requires_scan = applications_requires_scan self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - # TODO: allow these to be configured in yaml - self.high_threshold = 10 - self.med_threshold = 5 - self.low_threshold = 0 + if thresholds.get("app_executions") is None: + self.low_app_execution_threshold = 0 + self.med_app_execution_threshold = 5 + self.high_app_execution_threshold = 10 + else: + self._set_application_execution_thresholds( + thresholds=[ + thresholds.get("app_executions")["low"], + thresholds.get("app_executions")["medium"], + thresholds.get("app_executions")["high"], + ] + ) + + def _set_application_execution_thresholds(self, thresholds: List[int]): + """ + Method that validates and then sets the application execution threshold. + + :param: thresholds: The application execution threshold to validate and set. + """ + if self._validate_thresholds( + thresholds=[ + thresholds[0], + thresholds[1], + thresholds[2], + ], + threshold_identifier="app_executions", + ): + self.low_app_execution_threshold = thresholds[0] + self.med_app_execution_threshold = thresholds[1] + self.high_app_execution_threshold = thresholds[2] def _categorise_num_executions(self, num_executions: int) -> int: """ - Represent number of file accesses as a categorical variable. + Represent number of application executions as a categorical variable. - :param num_access: Number of file accesses. + :param num_access: Number of application executions. :return: Bin number corresponding to the number of accesses. """ - if num_executions > self.high_threshold: + if num_executions > self.high_app_execution_threshold: return 3 - elif num_executions > self.med_threshold: + elif num_executions > self.med_app_execution_threshold: return 2 - elif num_executions > self.low_threshold: + elif num_executions > self.low_app_execution_threshold: return 1 return 0 @@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): return self.default_observation return { "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], + "health_status": application_state["health_state_visible"] + if self.applications_requires_scan + else application_state["health_state_actual"], "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"): :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["applications", config.application_name]) + return cls( + where=parent_where + ["applications", config.application_name], + applications_requires_scan=config.applications_requires_scan, + thresholds=config.thresholds, + ) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 1427776e..a3b77ec3 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict from primaite import DEFAULT_BANDWIDTH, getLogger from primaite.game.agent.interface import AbstractAgent, ProxyAgent +from primaite.game.agent.observations import NICObservation from primaite.game.agent.rewards import SharedReward from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT @@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP _LOGGER = getLogger(__name__) SERVICE_TYPES_MAPPING = { - "DNSClient": DNSClient, - "DNSServer": DNSServer, - "DatabaseService": DatabaseService, - "WebServer": WebServer, - "FTPClient": FTPClient, - "FTPServer": FTPServer, - "NTPClient": NTPClient, - "NTPServer": NTPServer, - "Terminal": Terminal, + "dns-client": DNSClient, + "dns-server": DNSServer, + "database-service": DatabaseService, + "web-server": WebServer, + "ftp-client": FTPClient, + "ftp-server": FTPServer, + "ntp-client": NTPClient, + "ntp-server": NTPServer, + "terminal": Terminal, } """List of available services that can be installed on nodes in the PrimAITE Simulation.""" @@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel): seed: int = None """Random number seed for RNGs.""" + generate_seed_value: bool = False + """Internally generated seed value.""" max_episode_length: int = 256 """Maximum number of episodes for the PrimAITE game.""" ports: List[Port] @@ -175,6 +178,7 @@ class PrimaiteGame: parameters=parameters, request=request, response=response, + observation=obs, ) def pre_timestep(self) -> None: @@ -263,6 +267,7 @@ class PrimaiteGame: node_sets_cfg = network_config.get("node_sets", []) # Set the NMNE capture config NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {})) + NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne for node_cfg in nodes_cfg: n_type = node_cfg["type"] @@ -293,6 +298,7 @@ class PrimaiteGame: if "users" in node_cfg and new_node.software_manager.software.get("user-manager"): user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa + for user_cfg in node_cfg["users"]: user_manager.add_user(**user_cfg, bypass_can_perform_action=True) @@ -407,6 +413,7 @@ class PrimaiteGame: agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: + agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds} new_agent = AbstractAgent.from_config(agent_cfg) game.agents[agent_cfg["ref"]] = new_agent if isinstance(new_agent, ProxyAgent): diff --git a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index ef4e75dd..52499ea6 100644 --- a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -50,40 +50,22 @@ "custom_c2_agent = \"\"\"\n", " - ref: CustomC2Agent\n", " team: RED\n", - " type: ProxyAgent\n", + " type: proxy-a.gent\n", "\n", " action_space:\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - " - node_name: client_1\n", - " applications:\n", - " - application_name: C2Server\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.1.21\n", - " - 192.168.1.14\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " action_map:\n", " 0:\n", " action: do_nothing\n", " options: {}\n", " 1:\n", - " action: node_application_install\n", + " action: node-application-install\n", " options:\n", - " node_id: 0\n", - " application_name: C2Beacon\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 2:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency:\n", @@ -92,10 +74,10 @@ " 3:\n", " action: node_application_execute\n", " options:\n", - " node_id: 0\n", - " application_id: 0\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 4:\n", - " action: c2_server_terminal_command\n", + " action: c2-server-terminal-command\n", " options:\n", " node_id: 1\n", " ip_address:\n", @@ -111,14 +93,14 @@ " 5:\n", " action: c2-server-ransomware-configure\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " config:\n", " server_ip_address: 192.168.1.14\n", " payload: ENCRYPT\n", " 6:\n", - " action: c2_server_data_exfiltrate\n", + " action: c2-server-data-exfiltrate\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " target_file_name: \"database.db\"\n", " target_folder_name: \"database\"\n", " exfiltration_folder_name: \"spoils\"\n", @@ -128,31 +110,27 @@ " password: admin\n", "\n", " 7:\n", - " action: c2_server_ransomware_launch\n", + " action: c2-server-ransomware-launch\n", " options:\n", - " node_id: 1\n", + " node_name: client_1\n", " 8:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency: 10\n", " masquerade_protocol: TCP\n", " masquerade_port: DNS\n", " 9:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " config:\n", " c2_server_ip_address: 192.168.10.22\n", " keep_alive_frequency:\n", " masquerade_protocol:\n", " masquerade_port:\n", - "\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", "\"\"\"\n", "c2_agent_yaml = yaml.safe_load(custom_c2_agent)" ] @@ -225,7 +203,7 @@ " nodes: # Node List\n", " - node_name: web_server\n", " applications: \n", - " - application_name: C2Beacon\n", + " - application_name: c2-beacon\n", " ...\n", " ...\n", " action_map:\n", @@ -233,7 +211,7 @@ " action: node_application_install \n", " options:\n", " node_id: 0 # Index 0 at the node list.\n", - " application_name: C2Beacon\n", + " application_name: c2-beacon\n", "```" ] }, @@ -268,7 +246,7 @@ " action_map:\n", " ...\n", " 2:\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", " node_id: 0 # Node Index\n", " config: # Further information about these config options can be found at the bottom of this notebook.\n", @@ -286,7 +264,7 @@ "outputs": [], "source": [ "env.step(2)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "web_server.software_manager.show()\n", "c2_beacon.show()" ] @@ -307,13 +285,13 @@ " nodes: # Node List\n", " - node_name: web_server\n", " applications: \n", - " - application_name: C2Beacon\n", + " - application_name: c2-beacon\n", " ...\n", " ...\n", " action_map:\n", " ...\n", " 3:\n", - " action: node_application_execute\n", + " action: node-application-execute\n", " options:\n", " node_id: 0\n", " application_id: 0\n", @@ -374,11 +352,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 4:\n", - " action: C2_SERVER_TERMINAL_COMMAND\n", + " action: c2-server-terminal-command\n", " options:\n", " node_id: 1\n", " ip_address:\n", @@ -431,7 +409,7 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 5:\n", @@ -459,7 +437,7 @@ "metadata": {}, "outputs": [], "source": [ - "ransomware_script: RansomwareScript = web_server.software_manager.software[\"RansomwareScript\"]\n", + "ransomware_script: RansomwareScript = web_server.software_manager.software[\"ransomware-script\"]\n", "web_server.software_manager.show()\n", "ransomware_script.show()" ] @@ -483,11 +461,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 6:\n", - " action: c2_server_data_exfiltrate\n", + " action: c2-server-data-exfiltrate\n", " options:\n", " node_id: 1\n", " target_file_name: \"database.db\"\n", @@ -549,11 +527,11 @@ " ...\n", " - node_name: client_1\n", " applications: \n", - " - application_name: C2Server\n", + " - application_name: c2-server\n", " ...\n", " action_map:\n", " 7:\n", - " action: c2_server_ransomware_launch\n", + " action: c2-server-ransomware-launch\n", " options:\n", " node_id: 1\n", "```\n" @@ -598,20 +576,20 @@ "custom_blue_agent_yaml = \"\"\"\n", " - ref: defender\n", " team: BLUE\n", - " type: ProxyAgent\n", + " type: proxy-agent\n", "\n", " observation_space:\n", - " type: CUSTOM\n", + " type: custom\n", " options:\n", " components:\n", - " - type: NODES\n", + " - type: nodes\n", " label: NODES\n", " options:\n", " hosts:\n", " - hostname: web_server\n", " applications:\n", - " - application_name: C2Beacon\n", - " - application_name: RansomwareScript\n", + " - application_name: c2-beacon\n", + " - application_name: ransomware-script\n", " folders:\n", " - folder_name: exfiltration_folder\n", " files:\n", @@ -661,7 +639,7 @@ " - UDP\n", " num_rules: 10\n", "\n", - " - type: LINKS\n", + " - type: links\n", " label: LINKS\n", " options:\n", " link_references:\n", @@ -675,7 +653,7 @@ " - switch_2:eth-1<->client_1:eth-1\n", " - switch_2:eth-2<->client_2:eth-1\n", " - switch_2:eth-7<->security_suite:eth-2\n", - " - type: \"NONE\"\n", + " - type: \"none\"\n", " label: ICS\n", " options: {}\n", "\n", @@ -685,16 +663,16 @@ " action: do_nothing\n", " options: {}\n", " 1:\n", - " action: node_application_remove\n", + " action: node-application-remove\n", " options:\n", - " node_id: 0\n", + " node_name: web-server\n", " application_name: C2Beacon\n", " 2:\n", - " action: node_shutdown\n", + " action: node-shutdown\n", " options:\n", - " node_id: 0\n", + " node_name: web-server\n", " 3:\n", - " action: router_acl_add_rule\n", + " action: router-acl-add-rule\n", " options:\n", " target_router: router_1\n", " position: 1\n", @@ -707,36 +685,6 @@ " source_wildcard_id: 0\n", " dest_wildcard_id: 0\n", "\n", - "\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - "\n", - " - node_name: database_server\n", - " folders:\n", - " - folder_name: database\n", - " files:\n", - " - file_name: database.db\n", - " services:\n", - " - service_name: DatabaseService\n", - " - node_name: router_1\n", - "\n", - " max_folders_per_node: 2\n", - " max_files_per_folder: 2\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.10.21\n", - " - 192.168.1.12\n", - " wildcard_list:\n", - " - 0.0.0.1\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", - "\n", " agent_settings:\n", " flatten_obs: False\n", "\"\"\"\n", @@ -875,7 +823,7 @@ "outputs": [], "source": [ "# Installing RansomwareScript via C2 Terminal Commands\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)\n" @@ -1034,11 +982,11 @@ " web_server: Server = given_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "\n", " client_1.software_manager.install(C2Server)\n", - " c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + " c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", " c2_server.run()\n", "\n", " web_server.software_manager.install(C2Beacon)\n", - " c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + " c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", " c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n", " c2_beacon.establish()\n", "\n", @@ -1132,11 +1080,11 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1220,11 +1168,11 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n", + "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"ransomware-script\"],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", - "c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)" ] }, @@ -1345,7 +1293,7 @@ "metadata": {}, "outputs": [], "source": [ - "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database_server\")\n", + "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database-server\")\n", "database_server.software_manager.file_system.show(full=True)" ] }, @@ -1391,7 +1339,7 @@ "\n", "``` YAML\n", "...\n", - " action: configure_c2_beacon\n", + " action: configure-c2-beacon\n", " options:\n", " node_id: 0\n", " config:\n", @@ -1446,16 +1394,16 @@ "source": [ "web_server: Server = c2_config_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n", "web_server.software_manager.install(C2Beacon)\n", - "c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n", + "c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n", "\n", "client_1: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", "client_1.software_manager.install(C2Server)\n", - "c2_server_1: C2Server = client_1.software_manager.software[\"C2Server\"]\n", + "c2_server_1: C2Server = client_1.software_manager.software[\"c2-server\"]\n", "c2_server_1.run()\n", "\n", "client_2: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_2\")\n", "client_2.software_manager.install(C2Server)\n", - "c2_server_2: C2Server = client_2.software_manager.software[\"C2Server\"]\n", + "c2_server_2: C2Server = client_2.software_manager.software[\"c2-server\"]\n", "c2_server_2.run()" ] }, @@ -1759,6 +1707,16 @@ "\n", "display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "env.game.agents[\"CustomC2Agent\"].show_history()" + ] } ], "metadata": { diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb index 756fc44f..2dbf750e 100644 --- a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -47,7 +47,7 @@ "source": [ "def make_cfg_have_flat_obs(cfg):\n", " for agent in cfg['agents']:\n", - " if agent['type'] == \"ProxyAgent\":\n", + " if agent['type'] == \"proxy-agent\":\n", " agent['agent_settings']['flatten_obs'] = False" ] }, @@ -76,9 +76,9 @@ " # parse the info dict form step output and write out what the red agent is doing\n", " red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", " red_action = red_info.action\n", - " if red_action == 'do_nothing':\n", + " if red_action == 'do-nothing':\n", " red_str = 'DO NOTHING'\n", - " elif red_action == 'node_application_execute':\n", + " elif red_action == 'node-application-execute':\n", " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", " red_str = f\"ATTACK from {client}\"\n", " return red_str" @@ -147,36 +147,14 @@ "```yaml\n", " - ref: data_manipulation_attacker # name of agent\n", " team: RED # not used, just for human reference\n", - " type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n", + " type: red-database-corrupting-agent # type of agent - this lets primaite know which agent class to use\n", "\n", " # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n", " observation_space:\n", - " type: UC2RedObservation\n", + " type: uc2-red-observation # TODO: what\n", " options:\n", " nodes: {}\n", "\n", - " action_space:\n", - " \n", - " # The agent has access to the DataManipulationBoth on clients 1 and 2.\n", - " options:\n", - " nodes:\n", - " - node_name: client_1 # The network should have a node called client_1\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n", - " - node_name: client_2 # The network should have a node called client_2\n", - " applications:\n", - " - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n", - "\n", - " # not important\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 1\n", - "\n", - " # red agent does not need a reward function\n", - " reward_function:\n", - " reward_components:\n", - " - type: DUMMY\n", - "\n", " # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n", " agent_settings:\n", " start_settings:\n", @@ -211,15 +189,13 @@ " \n", " # \n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.8 # Probability that port scan is successful\n", " data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n", " payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n", " server_ip: 192.168.1.14 # IP address of server hosting the database\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n", + " - type: database-client # Database client must be installed in order for DataManipulationBot to function\n", " options:\n", " db_server_ip: 192.168.1.14 # IP address of server hosting the database\n", "```" @@ -354,19 +330,16 @@ "# Make attack always succeed.\n", "change = yaml.safe_load(\"\"\"\n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 1.0\n", " data_manipulation_p_of_success: 1.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", - " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " - type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " - type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", @@ -399,19 +372,16 @@ "# Make attack always fail.\n", "change = yaml.safe_load(\"\"\"\n", " applications:\n", - " - ref: data_manipulation_bot\n", - " type: DataManipulationBot\n", + " - type: data-manipulation-bot\n", " options:\n", " port_scan_p_of_success: 0.0\n", " data_manipulation_p_of_success: 0.0\n", " payload: \"DELETE\"\n", " server_ip: 192.168.1.14\n", - " - ref: client_1_web_browser\n", - " type: WebBrowser\n", + " - type: web-browser\n", " options:\n", " target_url: http://arcd.com/users/\n", - " - ref: client_1_database_client\n", - " type: DatabaseClient\n", + " - type: database-client\n", " options:\n", " db_server_ip: 192.168.1.14\n", "\"\"\")\n", diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index dbc6f0c1..2070e03c 100644 --- a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -684,6 +684,15 @@ " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.game.agents[\"data_manipulation_attacker\"].show_history()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -717,7 +726,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb index f8691d7d..58573ac6 100644 --- a/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb +++ b/src/primaite/notebooks/Getting-Information-Out-Of-PrimAITE.ipynb @@ -153,6 +153,49 @@ "PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n", "PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Viewing Agent history\n", + "\n", + "It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "# Run the training session to generate some resultant data.\n", + "for i in range(100):\n", + " env.step(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attacker = env.game.agents[\"data_manipulation_attacker\"]\n", + "\n", + "attacker.show_history()" + ] } ], "metadata": { @@ -171,7 +214,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb new file mode 100644 index 00000000..8f8ec24b --- /dev/null +++ b/src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb @@ -0,0 +1,479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrimAITE Developer mode\n", + "\n", + "PrimAITE has built in developer tools.\n", + "\n", + "The dev-mode is designed to help make the development of PrimAITE easier.\n", + "\n", + "`NOTE: For the purposes of the notebook, the commands are preceeded by \"!\". When running the commands, run it without the \"!\".`\n", + "\n", + "To display the available dev-mode options, run the command below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode --help" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save the current PRIMAITE_CONFIG to restore after the notebook runs\n", + "\n", + "from primaite import PRIMAITE_CONFIG\n", + "\n", + "temp_config = PRIMAITE_CONFIG.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dev mode options" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### enable\n", + "\n", + "Enables the dev mode for PrimAITE.\n", + "\n", + "This will enable the developer mode for PrimAITE.\n", + "\n", + "By default, when developer mode is enabled, session logs will be generated in the PRIMAITE_ROOT/sessions folder unless configured to be generated in another location." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode enable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### disable\n", + "\n", + "Disables the dev mode for PrimAITE.\n", + "\n", + "This will disable the developer mode for PrimAITE." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode disable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### show\n", + "\n", + "Shows if PrimAITE is running in dev mode or production mode.\n", + "\n", + "The command will also show the developer mode configuration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode show" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### config\n", + "\n", + "Configure the PrimAITE developer mode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --help" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### path\n", + "\n", + "Set the path where generated session files will be output.\n", + "\n", + "By default, this value will be in PRIMAITE_ROOT/sessions.\n", + "\n", + "To reset the path to default, run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config path -root\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config path --default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --sys-log-level or -slevel\n", + "\n", + "Set the system log level.\n", + "\n", + "This will override the system log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --sys-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -slevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --agent-log-level or -alevel\n", + "\n", + "Set the agent log level.\n", + "\n", + "This will override the agent log level in configurations and will make PrimAITE include the set log level and above.\n", + "\n", + "Available options are:\n", + "- `DEBUG`\n", + "- `INFO`\n", + "- `WARNING`\n", + "- `ERROR`\n", + "- `CRITICAL`\n", + "\n", + "Default value is `DEBUG`\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --agent-log-level DEBUG\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -alevel DEBUG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-sys-logs or -sys\n", + "\n", + "If enabled, developer mode will output system logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting sys logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-sys-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nsys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-agent-logs or -agent\n", + "\n", + "If enabled, developer mode will output agent action logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting agent action logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-agent-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nagent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-pcap-logs or -pcap\n", + "\n", + "If enabled, developer mode will output PCAP logs.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -pcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable outputting PCAP logs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-pcap-logs\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -npcap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### --output-to-terminal or -t\n", + "\n", + "If enabled, developer mode will output logs to the terminal.\n", + "\n", + "Example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --output-to-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -t" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To disable terminal outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config --no-terminal\n", + "\n", + "# or\n", + "\n", + "!primaite dev-mode config -nt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combining commands\n", + "\n", + "It is possible to combine commands to set the configuration.\n", + "\n", + "This saves having to enter multiple commands and allows for a much more efficient setting of PrimAITE developer mode configurations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of setting system log level and enabling the system logging:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel WARNING -sys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another example where the system log and agent action log levels are set and enabled and should be printed to terminal:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!primaite dev-mode config -slevel ERROR -sys -alevel ERROR -agent -t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restore PRIMAITE_CONFIG\n", + "from primaite.utils.cli.primaite_config_utils import update_primaite_application_config\n", + "\n", + "\n", + "global PRIMAITE_CONFIG\n", + "PRIMAITE_CONFIG[\"developer_mode\"] = temp_config[\"developer_mode\"]\n", + "update_primaite_application_config(config=PRIMAITE_CONFIG)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/Requests-and-Responses.ipynb b/src/primaite/notebooks/Requests-and-Responses.ipynb index 83aed07c..9260b29d 100644 --- a/src/primaite/notebooks/Requests-and-Responses.ipynb +++ b/src/primaite/notebooks/Requests-and-Responses.ipynb @@ -114,7 +114,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(f\"DNS Client state: {client.software_manager.software.get('DNSClient').operating_state.name}\")" + "print(f\"DNS Client state: {client.software_manager.software.get('dns-client').operating_state.name}\")" ] }, { diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 9aa4e96a..7c94d432 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -9,6 +9,13 @@ "© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simulation Layer Implementation." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -67,9 +74,9 @@ "source": [ "network: Network = basic_network()\n", "computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n", - "terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n", + "terminal_a: Terminal = computer_a.software_manager.software.get(\"terminal\")\n", "computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n", - "terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")" + "terminal_b: Terminal = computer_b.software_manager.software.get(\"terminal\")" ] }, { @@ -121,7 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])" + "term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"ransomware-script\"])" ] }, { @@ -169,6 +176,22 @@ "computer_b.file_system.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(terminal_a.last_response)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -207,6 +230,263 @@ "source": [ "computer_b.user_session_manager.show(include_historic=True, include_session_id=True)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Game Layer Implementation\n", + "\n", + "This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n", + "\n", + "The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n", + "\n", + "\n", + "| Game Layer Action | Simulation Layer |\n", + "|-----------------------------------|--------------------------|\n", + "| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n", + "| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n", + "| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Game Layer Setup\n", + "\n", + "Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n", + "\n", + "If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from primaite.session.environment import PrimaiteGymEnv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "custom_terminal_agent = \"\"\"\n", + " - ref: CustomC2Agent\n", + " team: RED\n", + " type: proxy-agent\n", + " observation_space: null\n", + " action_space:\n", + " options:\n", + " nodes:\n", + " - node_name: client_1\n", + " max_folders_per_node: 1\n", + " max_files_per_folder: 1\n", + " max_services_per_node: 2\n", + " max_nics_per_node: 8\n", + " max_acl_rules: 10\n", + " ip_list:\n", + " - 192.168.1.21\n", + " - 192.168.1.14\n", + " wildcard_list:\n", + " - 0.0.0.1\n", + " action_map:\n", + " 0:\n", + " action: do-nothing\n", + " options: {}\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_name: client_1\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22\n", + " 3:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_name: client_1\n", + " remote_ip: 192.168.10.22\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "\"\"\"\n", + "custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path()) as f:\n", + " cfg = yaml.safe_load(f)\n", + " # removing all agents & adding the custom agent.\n", + " cfg['agents'] = {}\n", + " cfg['agents'] = custom_terminal_agent_yaml\n", + "\n", + "env = PrimaiteGymEnv(env_config=cfg)\n", + "\n", + "client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n", + "client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-local-command`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-local-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-local-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " command:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - dog.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(1)\n", + "client_1.file_system.show(full=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-session-remote-login`` \n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-session-remote-login\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 2:\n", + " action: node-session-remote-login\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " username: admin\n", + " password: admin\n", + " remote_ip: 192.168.10.22 # client_2's ip address.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(2)\n", + "client_2.session_manager.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Terminal Action | ``node-send-remote-command``\n", + "\n", + "The yaml snippet below shows all the relevant agent options for this action:\n", + "\n", + "```yaml\n", + "\n", + " action_space:\n", + " action_list:\n", + " ...\n", + " - type: node-send-remote-command\n", + " ...\n", + " options:\n", + " nodes: # Node List\n", + " - node_name: client_1\n", + " ...\n", + " ...\n", + " action_map:\n", + " 1:\n", + " action: node-send-remote-command\n", + " options:\n", + " node_id: 0 # Index 0 at the node list.\n", + " remote_ip: 192.168.10.22\n", + " commands:\n", + " - file_system\n", + " - create\n", + " - file\n", + " - downloads\n", + " - cat.png\n", + " - False\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(3)\n", + "client_2.file_system.show(full=True)" + ] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index b7a9a042..fa545dbc 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -26,14 +26,26 @@ except ModuleNotFoundError: _LOGGER.debug("Torch not available for importing") -def set_random_seed(seed: int) -> Union[None, int]: +def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]: """ Set random number generators. + If seed is None or -1 and generate_seed_value is True randomly generate a + seed value. + If seed is > -1 and generate_seed_value is True ignore the latter and use + the provide seed value. + :param seed: int + :param generate_seed_value: bool + :return: None or the int representing the seed used. """ if seed is None or seed == -1: - return None + if generate_seed_value: + rng = np.random.default_rng() + # 2**32-1 is highest value for python RNG seed. + seed = int(rng.integers(low=0, high=2**32 - 1)) + else: + return None elif seed < -1: raise ValueError("Invalid random number seed") # Seed python RNG @@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]: return seed +def log_seed_value(seed: int): + """Log the selected seed value to file.""" + path = SIM_OUTPUT.path / "seed.log" + with open(path, "w") as file: + file.write(f"Seed value = {seed}") + + class PrimaiteGymEnv(gymnasium.Env): """ Thin wrapper env to provide agents with a gymnasium API. @@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env): """Object that returns a config corresponding to the current episode.""" self.seed = self.episode_scheduler(0).get("game", {}).get("seed") """Get RNG seed from config file. NB: Must be before game instantiation.""" - self.seed = set_random_seed(self.seed) + self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value") + self.seed = set_random_seed(self.seed, self.generate_seed_value) self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) @@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env): _LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}") + log_seed_value(self.seed) + def action_masks(self) -> np.ndarray: """ Return the action mask for the agent. @@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if seed is not None: - set_random_seed(seed) + set_random_seed(seed, self.generate_seed_value) self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward if self.io.settings.save_agent_actions: diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 8653359a..50a7d2a4 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -864,7 +864,21 @@ class UserManager(Service, discriminator="user-manager"): """ rm = super()._init_request_manager() - # todo add doc about requeest schemas + # todo add doc about request schemas + rm.add_request( + "add_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool( + self.add_user(username=request[0], password=request[1], is_admin=request[2]) + ) + ), + ) + rm.add_request( + "disable_user", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0])) + ), + ) rm.add_request( "change_password", RequestType( @@ -1572,7 +1586,7 @@ class Node(SimComponent, ABC): operating_state: Any = None - users: Any = None # Temporary to appease "extra=forbid" + users: List[Dict] = [] # Temporary to appease "extra=forbid" config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema()) """Configuration items within Node""" @@ -1638,6 +1652,8 @@ class Node(SimComponent, ABC): self._install_system_software() self.session_manager.node = self self.session_manager.software_manager = self.software_manager + for user in self.config.users: + self.user_manager.add_user(**user, bypass_can_perform_action=True) @property def user_manager(self) -> Optional[UserManager]: @@ -1769,7 +1785,7 @@ class Node(SimComponent, ABC): """ application_name = request[0] if self.software_manager.software.get(application_name): - self.sys_log.warning(f"Can't install {application_name}. It's already installed.") + self.sys_log.info(f"Can't install {application_name}. It's already installed.") return RequestResponse(status="success", data={"reason": "already installed"}) application_class = Application._registry[application_name] self.software_manager.install(application_class) diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 76d9167c..86ac790c 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -2,11 +2,12 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, ClassVar, Dict, Literal, Optional +from typing import Any, ClassVar, Dict, List, Literal, Optional from pydantic import Field from primaite import getLogger +from primaite.simulator.file_system.file_type import FileType from primaite.simulator.network.hardware.base import ( IPWiredNetworkInterface, Link, @@ -313,7 +314,7 @@ class HostNode(Node, discriminator="host-node"): """ SYSTEM_SOFTWARE: ClassVar[Dict] = { - "HostARP": HostARP, + "host-arp": HostARP, "icmp": ICMP, "dns-client": DNSClient, "ntp-client": NTPClient, @@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"): ip_address: IPV4Address services: Any = None # temporarily unset to appease extra="forbid" applications: Any = None # temporarily unset to appease extra="forbid" - folders: Any = None # temporarily unset to appease extra="forbid" + folders: List[Dict] = {} # temporarily unset to appease extra="forbid" network_interfaces: Any = None # temporarily unset to appease extra="forbid" config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema()) @@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"): super().__init__(**kwargs) self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask)) + for folder in self.config.folders: + # handle empty foler defined by just a string + self.file_system.create_folder(folder["folder_name"]) + + for file in folder.get("files", []): + self.file_system.create_file( + folder_name=folder["folder_name"], + file_name=file["file_name"], + size=file.get("size", 0), + file_type=FileType[file.get("type", "UNKNOWN").upper()], + ) + @property def nmap(self) -> Optional[NMAP]: """ diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 2cbc23d2..c872b8b3 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"): Example: >>> from primaite.simulator.network.transmission.network_layer import IPProtocol - >>> from primaite.simulator.network.transmission.transport_layer import Port + >>> from primaite.utils.validation.port import Port >>> firewall = Firewall(hostname="Firewall1") >>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0") >>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0") diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 3b35600b..47ee3169 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -467,6 +467,7 @@ class AccessControlList(SimComponent): """Check if a packet with the given properties is permitted through the ACL.""" permitted = False rule: ACLRule = None + for _rule in self._acl: if not _rule: continue @@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"): config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema()) SYSTEM_SOFTWARE: ClassVar[Dict] = { - "UserSessionManager": UserSessionManager, - "UserManager": UserManager, - "Terminal": Terminal, + "user-session-manager": UserSessionManager, + "user-manager": UserManager, + "terminal": Terminal, } network_interfaces: Dict[str, RouterInterface] = {} @@ -1385,6 +1386,12 @@ class Router(NetworkNode, discriminator="router"): return False + def subject_to_acl(self, frame: Frame) -> bool: + """Check that frame is subject to ACL rules.""" + if frame.ip.protocol == "udp" and frame.is_arp: + return False + return True + def receive_frame(self, frame: Frame, from_network_interface: RouterInterface): """ Processes an incoming frame received on one of the router's interfaces. @@ -1398,8 +1405,12 @@ class Router(NetworkNode, discriminator="router"): if self.operating_state != NodeOperatingState.ON: return - # Check if it's permitted - permitted, rule = self.acl.is_permitted(frame) + if self.subject_to_acl(frame=frame): + # Check if it's permitted + permitted, rule = self.acl.is_permitted(frame) + else: + permitted = True + rule = None if not permitted: at_port = self._get_port_of_nic(from_network_interface) diff --git a/src/primaite/simulator/network/transmission/data_link_layer.py b/src/primaite/simulator/network/transmission/data_link_layer.py index e7c2a124..a07194a4 100644 --- a/src/primaite/simulator/network/transmission/data_link_layer.py +++ b/src/primaite/simulator/network/transmission/data_link_layer.py @@ -163,7 +163,7 @@ class Frame(BaseModel): """ Checks if the Frame is an ARP (Address Resolution Protocol) packet. - This is determined by checking if the destination port of the TCP header is equal to the ARP port. + This is determined by checking if the destination and source port of the UDP header is equal to the ARP port. :return: True if the Frame is an ARP packet, otherwise False. """ diff --git a/src/primaite/simulator/system/services/arp/arp.py b/src/primaite/simulator/system/services/arp/arp.py index b0630d5d..c6b687ce 100644 --- a/src/primaite/simulator/system/services/arp/arp.py +++ b/src/primaite/simulator/system/services/arp/arp.py @@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"): :param markdown: If True, format the output as Markdown. Otherwise, use plain text. """ - table = PrettyTable(["IP Address", "MAC Address", "Via"]) + table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"]) if markdown: table.set_style(MARKDOWN) table.align = "l" @@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"): str(ip), arp.mac_address, self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address, + self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num, ] ) print(table) diff --git a/src/primaite/simulator/system/services/terminal/terminal.py b/src/primaite/simulator/system/services/terminal/terminal.py index 2ce7d176..112f6abc 100644 --- a/src/primaite/simulator/system/services/terminal/terminal.py +++ b/src/primaite/simulator/system/services/terminal/terminal.py @@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"): _client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {} """Dictionary of connect requests made to remote nodes.""" + _last_response: Optional[RequestResponse] = None + """Last response received from RequestManager, for returning remote RequestResponse.""" + def __init__(self, **kwargs): kwargs["name"] = "terminal" kwargs["port"] = PORT_LOOKUP["SSH"] kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"] super().__init__(**kwargs) + @property + def last_response(self) -> Optional[RequestResponse]: + """Public version of _last_response attribute.""" + return self._last_response + def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"): return RequestResponse(status="failure", data={}) rm.add_request( - "node-session-remote-login", + "node_session_remote_login", request_type=RequestType(func=_remote_login), ) @@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"): command: str = request[1]["command"] remote_connection = self._get_connection_from_ip(ip_address=ip_address) if remote_connection: - outcome = remote_connection.execute(command) - if outcome: - return RequestResponse( - status="success", - data={}, - ) - else: - return RequestResponse( - status="failure", - data={}, - ) + remote_connection.execute(command) + return self.last_response if not None else RequestResponse(status="failure", data={}) + return RequestResponse( + status="failure", + data={"reason": "Failed to execute command."}, + ) rm.add_request( "send_remote_command", request_type=RequestType(func=remote_execute_request), ) + def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse: + """Executes a command using a local terminal session.""" + command: str = request[2]["command"] + local_connection = self._process_local_login(username=request[0], password=request[1]) + if local_connection: + outcome = local_connection.execute(command) + if outcome: + return RequestResponse( + status="success", + data={"reason": outcome}, + ) + return RequestResponse( + status="success", + data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"}, + ) + + rm.add_request( + "send_local_command", + request_type=RequestType(func=local_execute_request), + ) + return rm def execute(self, command: List[Any]) -> Optional[RequestResponse]: """Execute a passed ssh command via the request manager.""" - return self.parent.apply_request(command) + self._last_response = self.parent.apply_request(command) + return self._last_response def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]: """Find Remote Terminal Connection from a given IP.""" @@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"): """ source_ip = kwargs["frame"].ip.src_ip_address self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}") + self._last_response = None # Clear last response + if isinstance(payload, SSHPacket): if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST: # validate & add connection @@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"): session_id=session_id, source_ip=source_ip, ) + self._last_response: RequestResponse = RequestResponse( + status="success", data={"reason": "Login Successful"} + ) elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST: # Requesting a command to be executed @@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"): payload.connection_uuid ) remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep - self.execute(command) + self._last_response: RequestResponse = self.execute(command) + + if self._last_response.status == "success": + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + else: + transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED + + payload: SSHPacket = SSHPacket( + payload=self._last_response, + transport_message=transport_message, + connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA, + ) + self.software_manager.send_payload_to_session_manager( + payload=payload, dest_port=self.port, session_id=session_id + ) return True else: self.sys_log.error( f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command." ) + elif ( + payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS + or SSHTransportMessage.SSH_MSG_SERVICE_FAILED + ): + # Likely receiving command ack from remote. + self._last_response = payload.payload if isinstance(payload, dict) and payload.get("type"): if payload["type"] == "disconnect": diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index 2eddefc1..3f8760c4 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"): :type: payload: HttpRequestPacket """ response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload) - try: - parsed_url = urlparse(payload.request_url) - path = parsed_url.path.strip("/") - if len(path) < 1: + parsed_url = urlparse(payload.request_url) + path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else "" + + if len(path) < 1: + # query succeeded + response.status_code = HttpStatusCode.OK + + if path.startswith("users"): + # get data from DatabaseServer + # get all users + if not self._establish_db_connection(): + # unable to create a db connection + response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + return response + + if self.db_connection.query("SELECT"): # query succeeded + self.set_health_state(SoftwareHealthState.GOOD) response.status_code = HttpStatusCode.OK + else: + self.set_health_state(SoftwareHealthState.COMPROMISED) + return response - if path.startswith("users"): - # get data from DatabaseServer - # get all users - if not self.db_connection: - self._establish_db_connection() - - if self.db_connection.query("SELECT"): - # query succeeded - self.set_health_state(SoftwareHealthState.GOOD) - response.status_code = HttpStatusCode.OK - else: - self.set_health_state(SoftwareHealthState.COMPROMISED) - - return response - except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) - # something went wrong on the server - response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR - return response - - def _establish_db_connection(self) -> None: + def _establish_db_connection(self) -> bool: """Establish a connection to db.""" + # if active db connection, return true + if self.db_connection: + return True + + # otherwise, try to create db connection db_client = self.software_manager.software.get("database-client") + + if db_client is None: + return False # database client not installed + self.db_connection: DatabaseClientConnection = db_client.get_new_connection() + return self.db_connection is not None def send( self, diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index c9ac5f8d..7ffe4c08 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -25,7 +25,19 @@ game: - ICMP - TCP - UDP - + thresholds: + nmne: + high: 100 + medium: 25 + low: 5 + file_access: + high: 10 + medium: 5 + low: 2 + app_executions: + high: 5 + medium: 3 + low: 2 agents: - ref: client_2_green_user team: GREEN @@ -64,10 +76,16 @@ agents: options: hosts: - hostname: client_1 + applications: + - application_name: WebBrowser + folders: + - folder_name: root + files: + - file_name: "test.txt" - hostname: client_2 - hostname: client_3 num_services: 1 - num_applications: 0 + num_applications: 1 num_folders: 1 num_files: 1 num_nics: 2 @@ -182,6 +200,10 @@ simulation: options: ntp_server_ip: 192.168.1.10 - type: ntp-server + folders: + - folder_name: root + files: + - file_name: test.txt - hostname: client_2 type: computer ip_address: 192.168.10.22 diff --git a/tests/assets/configs/nodes_with_initial_files.yaml b/tests/assets/configs/nodes_with_initial_files.yaml new file mode 100644 index 00000000..d4c6406b --- /dev/null +++ b/tests/assets/configs/nodes_with_initial_files.yaml @@ -0,0 +1,226 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# +io_settings: + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + sys_log_level: WARNING + agent_log_level: INFO + save_agent_logs: true + write_agent_log_to_terminal: True + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: periodic-agent + action_space: + action_map: + 0: + action: do-nothing + options: {} + 1: + action: node-application-execute + options: + node_id: 0 + application_id: 0 + + agent_settings: + possible_start_nodes: [client_2,] + target_application: web-browser + start_step: 5 + frequency: 4 + variance: 3 + + + + - ref: defender + team: BLUE + type: proxy-agent + + observation_space: + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: client_1 + - hostname: client_2 + - hostname: client_3 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + include_nmne: false + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.23 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - switch_1:eth-1<->client_1:eth-1 + - switch_1:eth-2<->client_2:eth-1 + - type: none + label: ICS + options: {} + + action_space: + action_map: + 0: + action: do-nothing + options: {} + + reward_function: + reward_components: + - type: database-file-integrity + weight: 0.5 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + - type: web-server-404-penalty + weight: 0.5 + options: + node_hostname: web_server + service_name: web_server_web_service + + + agent_settings: + flatten_obs: true + +simulation: + network: + nodes: + + - type: switch + hostname: switch_1 + num_ports: 8 + + - hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - type: ransomware-script + - type: web-browser + options: + target_url: http://arcd.com/users/ + - type: database-client + options: + db_server_ip: 192.168.1.10 + server_password: arcd + - type: data-manipulation-bot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.21 + server_password: arcd + - type: dos-bot + options: + target_ip_address: 192.168.10.21 + payload: SPOOF DATA + port_scan_p_of_success: 0.8 + services: + - type: dns-client + options: + dns_server: 192.168.1.10 + - type: dns-server + options: + domain_mapping: + arcd.com: 192.168.1.10 + - type: database-service + options: + backup_server_ip: 192.168.1.10 + - type: web-server + - type: ftp-server + options: + server_password: arcd + - type: ntp-client + options: + ntp_server_ip: 192.168.1.10 + - type: ntp-server + - hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + folders: + - folder_name: empty_folder + - folder_name: downloads + files: + - file_name: "test.txt" + - file_name: "another_file.pwtwoti" + - folder_name: root + files: + - file_name: passwords + size: 663 + type: TXT + # pre installed services and applications + - hostname: client_3 + type: computer + ip_address: 192.168.10.23 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + start_up_duration: 0 + shut_down_duration: 0 + operating_state: "OFF" + # pre installed services and applications + + links: + - endpoint_a_hostname: switch_1 + endpoint_a_port: 1 + endpoint_b_hostname: client_1 + endpoint_b_port: 1 + bandwidth: 200 + - endpoint_a_hostname: switch_1 + endpoint_a_port: 2 + endpoint_b_hostname: client_2 + endpoint_b_port: 1 + bandwidth: 200 diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py index 4153adc0..627fc53b 100644 --- a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from tests import TEST_ASSETS_ROOT -BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" +BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" def load_config(config_path: Union[str, Path]) -> PrimaiteGame: @@ -24,3 +24,42 @@ def test_thresholds(): game = load_config(data_manipulation_config_path()) assert game.options.thresholds is not None + + +def test_nmne_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["nmne"] is not None + + # get NIC observation + nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0] + assert nic_obs.low_nmne_threshold == 5 + assert nic_obs.med_nmne_threshold == 25 + assert nic_obs.high_nmne_threshold == 100 + + +def test_file_access_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["file_access"] is not None + + # get file observation + file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0] + assert file_obs.low_file_access_threshold == 2 + assert file_obs.med_file_access_threshold == 5 + assert file_obs.high_file_access_threshold == 10 + + +def test_app_executions_threshold(): + """Test that the NMNE thresholds are properly loaded in by observation.""" + game = load_config(BASIC_SWITCHED_NETWORK_CONFIG) + + assert game.options.thresholds["app_executions"] is not None + + # get application observation + app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0] + assert app_obs.low_app_execution_threshold == 2 + assert app_obs.med_app_execution_threshold == 3 + assert app_obs.high_app_execution_threshold == 5 diff --git a/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py new file mode 100644 index 00000000..4c99a39f --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_node_file_system_config.py @@ -0,0 +1,64 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from pathlib import Path +from typing import Union + +import yaml + +from primaite.game.game import PrimaiteGame +from primaite.simulator.file_system.file_type import FileType +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_node_file_system_from_config(): + """Test that the appropriate files are instantiated in nodes when loaded from config.""" + game = load_config(BASIC_CONFIG) + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert client_1.software_manager.software.get("database-service") # database service should be installed + assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist + + assert client_1.software_manager.software.get("web-server") # web server should be installed + assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist + + client_2 = game.simulation.network.get_node_by_hostname("client_2") + + # database service should not be installed + assert client_2.software_manager.software.get("database-service") is None + # database files should not exist + assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None + + # web server should not be installed + assert client_2.software_manager.software.get("web-server") is None + # web files should not exist + assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None + + empty_folder = client_2.file_system.get_folder(folder_name="empty_folder") + assert empty_folder + assert len(empty_folder.files) == 0 # should have no files + + password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt") + assert password_file # should exist + assert password_file.file_type is FileType.TXT + assert password_file.size == 663 + + downloads_folder = client_2.file_system.get_folder(folder_name="downloads") + assert downloads_folder # downloads folder should exist + + test_txt = downloads_folder.get_file(file_name="test.txt") + assert test_txt # test.txt should exist + assert test_txt.file_type is FileType.TXT + + unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti") + assert unknown_file_type # unknown_file_type should exist + assert unknown_file_type.file_type is FileType.UNKNOWN diff --git a/tests/integration_tests/extensions/nodes/giga_switch.py b/tests/integration_tests/extensions/nodes/giga_switch.py index d9599618..5c202ed2 100644 --- a/tests/integration_tests/extensions/nodes/giga_switch.py +++ b/tests/integration_tests/extensions/nodes/giga_switch.py @@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"): if markdown: table.set_style(MARKDOWN) table.align = "l" - table.title = f"{self.hostname} Switch Ports" + table.title = f"{self.config.hostname} Switch Ports" for port_num, port in self.network_interface.items(): table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"]) print(table) diff --git a/tests/integration_tests/game_layer/actions/test_terminal_actions.py b/tests/integration_tests/game_layer/actions/test_terminal_actions.py index c39d8263..3ee97fb7 100644 --- a/tests/integration_tests/game_layer/actions/test_terminal_actions.py +++ b/tests/integration_tests/game_layer/actions/test_terminal_actions.py @@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame "username": "user123", "current_password": "password", "new_password": "different_password", - "remote_ip": str(server_1.network_interface[1].ip_address), }, ) agent.store_action(action) @@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam "username": "user123", "current_password": "password", "new_password": "different_password", - "remote_ip": str(server_1.network_interface[1].ip_address), }, ) agent.store_action(action) @@ -166,3 +164,55 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam assert server_1.file_system.get_folder("folder123") is None assert server_1.file_system.get_file("folder123", "doggo.pdf") is None + + +def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]): + game, agent = game_and_agent_fixture + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + # create a new user account on server_1 that will be logged into remotely + client_1_usm: UserManager = client_1.software_manager.software["user-manager"] + client_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "doggo.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_folder("folder123") + assert client_1.file_system.get_file("folder123", "doggo.pdf") + + # Change password + action = ( + "node-account-change-password", + { + "node_name": "client_1", + "username": "user123", + "current_password": "password", + "new_password": "different_password", + }, + ) + agent.store_action(action) + game.step() + + action = ( + "node-send-local-command", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "command": ["file_system", "create", "file", "folder123", "cat.pdf", False], + }, + ) + agent.store_action(action) + game.step() + + assert client_1.file_system.get_file("folder123", "cat.pdf") is None + client_1.session_manager.show() diff --git a/tests/integration_tests/game_layer/actions/test_user_account_actions.py b/tests/integration_tests/game_layer/actions/test_user_account_actions.py new file mode 100644 index 00000000..26b871db --- /dev/null +++ b/tests/integration_tests/game_layer/actions/test_user_account_actions.py @@ -0,0 +1,176 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import pytest + +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.utils.validation.port import Port, PORT_LOOKUP + + +@pytest.fixture +def game_and_agent_fixture(game_and_agent): + """Create a game with a simple agent that can be controlled by the tests.""" + game, agent = game_and_agent + + client_1: Computer = game.simulation.network.get_node_by_hostname("client_1") + client_1.start_up_duration = 3 + + return (game, agent) + + +def test_user_account_add_user_action(game_and_agent_fixture): + """Tests the add user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + assert len(client_1.user_manager.users) == 1 # admin is created by default + assert len(client_1.user_manager.admins) == 1 + + # add admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + # add non admin account + action = ( + "node-account-add-user", + {"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False}, + ) + agent.store_action(action) + game.step() + + assert len(client_1.user_manager.users) == 3 # new user added + assert len(client_1.user_manager.admins) == 2 + + +def test_user_account_disable_user_action(game_and_agent_fixture): + """Tests the disable user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + assert len(client_1.user_manager.users) == 2 # new user added + assert len(client_1.user_manager.admins) == 2 + + test_user = client_1.user_manager.users.get("test") + assert test_user + assert test_user.disabled is not True + + # disable test account + action = ( + "node-account-disable-user", + { + "node_name": "client_1", + "username": "test", + }, + ) + agent.store_action(action) + game.step() + assert test_user.disabled + + +def test_user_account_change_password_action(game_and_agent_fixture): + """Tests the change password user account action.""" + game, agent = game_and_agent_fixture + client_1 = game.simulation.network.get_node_by_hostname("client_1") + + client_1.user_manager.add_user(username="test", password="password", is_admin=True) + + test_user = client_1.user_manager.users.get("test") + assert test_user.password == "password" + + # change account password + action = ( + "node-account-change-password", + {"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"}, + ) + agent.store_action(action) + game.step() + + assert test_user.password == "2Hard_2_Hack" + + +def test_user_account_create_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to create new users.""" + game, agent = game_and_agent_fixture + + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Create a new user account via terminal. + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "add_user", "new_user", "new_pass", True], + }, + ) + agent.store_action(action) + game.step() + new_user = server_1.user_manager.users.get("new_user") + assert new_user + assert new_user.password == "new_pass" + assert new_user.disabled is not True + + +def test_user_account_disable_terminal_action(game_and_agent_fixture): + """Tests that agents can use the terminal to disable users.""" + game, agent = game_and_agent_fixture + router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4) + + server_1 = game.simulation.network.get_node_by_hostname("server_1") + server_1_usm = server_1.software_manager.software["user-manager"] + server_1_usm.add_user("user123", "password", is_admin=True) + + action = ( + "node-session-remote-login", + { + "node_name": "client_1", + "username": "user123", + "password": "password", + "remote_ip": str(server_1.network_interface[1].ip_address), + }, + ) + agent.store_action(action) + game.step() + assert agent.history[-1].response.status == "success" + + # Disable a user via terminal + action = ( + "node-send-remote-command", + { + "node_name": "client_1", + "remote_ip": str(server_1.network_interface[1].ip_address), + "command": ["service", "user-manager", "disable_user", "user123"], + }, + ) + agent.store_action(action) + game.step() + + new_user = server_1.user_manager.users.get("user123") + assert new_user + assert new_user.disabled is True diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 7323461c..722fd294 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -44,6 +44,38 @@ def test_file_observation(simulation): assert observation_state.get("health_status") == 3 # corrupted +def test_config_file_access_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + file_obs = FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert file_obs.high_file_access_threshold == 9 + assert file_obs.med_file_access_threshold == 6 + assert file_obs.low_file_access_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + FileObservation( + where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, + file_system_requires_scan=True, + thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}}, + ) + + def test_folder_observation(simulation): """Test the folder observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index b5e5ca81..30eccb06 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -77,6 +77,14 @@ def test_nic(simulation): nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True) + # The Simulation object created by the fixture also creates the + # NICObservation class with the NICObservation.capture_nmnme class variable + # set to False. Under normal (non-test) circumstances this class variable + # is set from a config file such as data_manipulation.yaml. So although + # capture_nmne is set to True in the NetworkInterface class it's still False + # in the NICObservation class so we set it now. + nic_obs.capture_nmne = True + # Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs nmne_config = { "capture_nmne": True, # Enable the capture of MNEs @@ -115,14 +123,11 @@ def test_nic_categories(simulation): assert nic_obs.low_nmne_threshold == 0 # default -@pytest.mark.skip(reason="Feature not implemented yet") def test_config_nic_categories(simulation): pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}}, include_nmne=True, ) @@ -133,20 +138,16 @@ def test_config_nic_categories(simulation): with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=9, - med_nmne_threshold=6, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}}, include_nmne=True, ) with pytest.raises(Exception): # should throw an error NICObservation( - where=["network", "nodes", pc.hostname, "NICs", 1], - low_nmne_threshold=3, - med_nmne_threshold=9, - high_nmne_threshold=9, + where=["network", "nodes", pc.config.hostname, "NICs", 1], + thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}}, include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 09eb3fe4..aef60bc2 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -39,6 +39,8 @@ def test_host_observation(simulation): folders=[], network_interfaces=[], file_system_requires_scan=True, + services_requires_scan=True, + applications_requires_scan=True, include_users=False, ) diff --git a/tests/integration_tests/game_layer/observations/test_obs_data_capture.py b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py new file mode 100644 index 00000000..e8bdea22 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_obs_data_capture.py @@ -0,0 +1,28 @@ +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +import json + +from primaite.session.environment import PrimaiteGymEnv +from primaite.session.io import PrimaiteIO +from tests import TEST_ASSETS_ROOT + +DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml" + + +def test_obs_data_in_log_file(): + """Create a log file of AgentHistoryItems and check observation data is + included. Assumes that data_manipulation.yaml has an agent labelled + 'defender' with a non-null observation space. + The log file will be in: + primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions + """ + env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG) + env.reset() + for _ in range(10): + env.step(0) + env.reset() + io = PrimaiteIO() + path = io.generate_agent_actions_save_path(episode=1) + with open(path, "r") as f: + j = json.load(f) + + assert type(j["0"]["defender"]["observation"]) == dict diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py index 28cdaf01..1ebff10c 100644 --- a/tests/integration_tests/game_layer/observations/test_software_observations.py +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -29,7 +29,9 @@ def test_service_observation(simulation): ntp_server = pc.software_manager.software.get("ntp-server") assert ntp_server - service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"]) + service_obs = ServiceObservation( + where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True + ) assert service_obs.space["operating_status"] == spaces.Discrete(7) assert service_obs.space["health_status"] == spaces.Discrete(5) @@ -54,7 +56,9 @@ def test_application_observation(simulation): web_browser: WebBrowser = pc.software_manager.software.get("web-browser") assert web_browser - app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"]) + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True + ) web_browser.close() observation_state = app_obs.observe(simulation.describe_state()) @@ -69,3 +73,33 @@ def test_application_observation(simulation): assert observation_state.get("health_status") == 1 assert observation_state.get("operating_status") == 1 # running assert observation_state.get("num_executions") == 1 + + +def test_application_executions_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + app_obs = ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}}, + ) + + assert app_obs.high_app_execution_threshold == 9 + assert app_obs.med_app_execution_threshold == 6 + assert app_obs.low_app_execution_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}}, + ) + + with pytest.raises(Exception): + # should throw an error + ApplicationObservation( + where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"], + applications_requires_scan=False, + thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}}, + ) diff --git a/tests/integration_tests/game_layer/test_RNG_seed.py b/tests/integration_tests/game_layer/test_RNG_seed.py index 45fa445d..2b80e153 100644 --- a/tests/integration_tests/game_layer/test_RNG_seed.py +++ b/tests/integration_tests/game_layer/test_RNG_seed.py @@ -7,6 +7,7 @@ import yaml from primaite.config.load import data_manipulation_config_path from primaite.game.agent.interface import AgentHistoryItem from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator import SIM_OUTPUT @pytest.fixture() @@ -33,6 +34,11 @@ def test_rng_seed_set(create_env): assert a == b + # Check that seed log file was created. + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + assert file + def test_rng_seed_unset(create_env): """Test with no RNG seed.""" @@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env): b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"] assert a != b + + +def test_for_generated_seed(): + """ + Show that setting generate_seed_value to true producess a valid seed. + """ + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + cfg["game"]["generate_seed_value"] = True + PrimaiteGymEnv(env_config=cfg) + path = SIM_OUTPUT.path / "seed.log" + with open(path, "r") as file: + data = file.read() + + assert data.split(" ")[3] != None diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 03b94ab7..59bee385 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -22,6 +22,7 @@ from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.system.applications.application import ApplicationOperatingState from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.software import SoftwareHealthState @@ -107,7 +108,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox """ Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation. - The acl starts off with 4 rules, and we add a rule, and check that the acl now has 5 rules. + The ACL starts off with 4 rules, and we add a rule, and check that the ACL now has 5 rules. """ game, agent = game_and_agent @@ -164,11 +165,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox }, ) agent.store_action(action) - print(agent.most_recent_action) game.step() - print(agent.most_recent_action) + # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 - print(router.acl.show()) assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2 @@ -180,7 +179,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P # 1: Check that http traffic is going across the network nicely. client_1 = game.simulation.network.get_node_by_hostname("client_1") server_1 = game.simulation.network.get_node_by_hostname("server_1") - router = game.simulation.network.get_node_by_hostname("router") + router: Router = game.simulation.network.get_node_by_hostname("router") + assert router.acl.num_rules == 4 browser: WebBrowser = client_1.software_manager.software.get("web-browser") browser.run() diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index ea7fbc99..80e7c3b3 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,5 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +from itertools import product + +import yaml + +from primaite.config.load import data_manipulation_config_path from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.session.environment import PrimaiteGymEnv from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server @@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network): assert web_nic_obs["outbound"] == expected_nmne assert db_nic_obs["inbound"] == expected_nmne uc2_network.apply_timestep(timestep=0) + + +def test_nmne_parameter_settings(): + """ + Check that the four permutations of the values of capture_nmne and + include_nmne work as expected. + """ + + with open(data_manipulation_config_path(), "r") as f: + cfg = yaml.safe_load(f) + + DEFENDER = 3 + for capture, include in product([True, False], [True, False]): + cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture + cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include + PrimaiteGymEnv(env_config=cfg) diff --git a/tests/integration_tests/system/test_arp.py b/tests/integration_tests/system/test_arp.py index 055d58c6..b9a92255 100644 --- a/tests/integration_tests/system/test_arp.py +++ b/tests/integration_tests/system/test_arp.py @@ -1,6 +1,7 @@ -# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK -from primaite.simulator.network.hardware.nodes.network.router import RouterARP +# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP from primaite.simulator.system.services.arp.arp import ARP +from primaite.utils.validation.port import PORT_LOOKUP from tests.integration_tests.network.test_routing import multi_hop_network @@ -48,3 +49,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network): actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address) assert actual_result == expected_result + + +def test_arp_not_affected_by_acl(multi_hop_network): + pc_a = multi_hop_network.get_node_by_hostname("pc_a") + router_1: Router = multi_hop_network.get_node_by_hostname("router_1") + + # Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic + # as it operates a different layer within the network. + router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23) + + pc_a_arp: ARP = pc_a.software_manager.arp + + expected_result = router_1.network_interface[2].mac_address + actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address) + + assert actual_result == expected_result diff --git a/tests/unit_tests/_primaite/_game/_agent/test_observations.py b/tests/unit_tests/_primaite/_game/_agent/test_observations.py index 5d5921a9..3df6ca0a 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_observations.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_observations.py @@ -1,10 +1,11 @@ # © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK +import json from typing import List import pytest import yaml -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.host_observations import HostObservation @@ -136,3 +137,227 @@ class TestFileSystemRequiresScan: [], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True ) assert obs_requiring_scan.observe(folder_state)["health_status"] == 1 + + +class TestServicesRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("services_requires_scan: true", True), + ("services_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set service_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: web-server + - service_name: dns-client + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + services: + - service_name: ftp-server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 3 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure services require scan by default + + def test_services_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1} + + obs_requiring_scan = ServiceObservation([], services_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value + + +class TestApplicationsRequiresScan: + @pytest.mark.parametrize( + ("yaml_option_string", "expected_val"), + ( + ("applications_requires_scan: true", True), + ("applications_requires_scan: false", False), + (" ", True), + ), + ) + def test_obs_config(self, yaml_option_string, expected_val): + """Check that the default behaviour is to set applications_requires_scan to True.""" + obs_cfg_yaml = f""" + type: custom + options: + components: + - type: nodes + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + applications: + - application_name: web-browser + - hostname: client_2 + applications: + - application_name: web-browser + - application_name: database-client + num_services: 0 + num_applications: 3 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + {yaml_option_string} + include_nmne: true + monitored_traffic: + icmp: + - NONE + tcp: + - DNS + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: links + label: LINKS + options: + link_references: + - router_1:eth-1<->switch_1:eth-8 + - router_1:eth-2<->switch_2:eth-8 + - switch_1:eth-1<->domain_controller:eth-1 + - switch_1:eth-2<->web_server:eth-1 + - switch_1:eth-3<->database_server:eth-1 + - switch_1:eth-4<->backup_server:eth-1 + - switch_1:eth-7<->security_suite:eth-1 + - switch_2:eth-1<->client_1:eth-1 + - switch_2:eth-2<->client_2:eth-1 + - switch_2:eth-7<->security_suite:eth-2 + - type: none + label: ICS + options: {{}} + + """ + + cfg = yaml.safe_load(obs_cfg_yaml) + manager = ObservationManager.from_config(cfg) + + hosts: List[HostObservation] = manager.obs.components["NODES"].hosts + for i, host in enumerate(hosts): + services: List[ServiceObservation] = host.services + for j, service in enumerate(services): + val = service.services_requires_scan + print(f"host {i} service {j} {val}") + assert val == expected_val # Make sure applications require scan by default + + def test_applications_requires_scan(self): + state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1} + + obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True) + assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value + + obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False) + assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py index 91369f6c..81e05467 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_ftp_client.py @@ -73,7 +73,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client): assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR -def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client): +def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client): """Method send_file should return false if no file to send.""" assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None diff --git a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py index 32fcae9a..3b2377e9 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_services/test_terminal.py @@ -6,6 +6,7 @@ import pytest from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.interface.request import RequestResponse from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server @@ -442,3 +443,59 @@ def test_terminal_connection_timeout(basic_network): assert len(computer_b.user_session_manager.remote_sessions) == 0 assert not remote_connection.is_active + + +def test_terminal_last_response_updates(basic_network): + """Test that the _last_response within Terminal correctly updates.""" + network: Network = basic_network + computer_a: Computer = network.get_node_by_hostname("node_a") + terminal_a: Terminal = computer_a.software_manager.software.get("terminal") + computer_b: Computer = network.get_node_by_hostname("node_b") + + assert terminal_a.last_response is None + + remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11") + + # Last response should be a successful logon + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update following successful install + assert terminal_a.last_response == RequestResponse(status="success", data={}) + + remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"]) + + # Last response should now update to success, but with supplied reason. + assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"}) + + remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False]) + + # Check file was created. + assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf") + + # Last response should be confirmation of file creation. + assert terminal_a.last_response == RequestResponse( + status="success", + data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400}, + ) + + remote_connection.execute( + command=[ + "service", + "ftp-client", + "send", + { + "dest_ip_address": "192.168.0.2", + "src_folder": "folder123", + "src_file_name": "cat.pdf", + "dest_folder": "root", + "dest_file_name": "cat.pdf", + }, + ] + ) + + assert terminal_a.last_response == RequestResponse( + status="failure", + data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"}, + ) From b4b0f99c23baaefc4cf9750cad1f28c8edba23c8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 17:57:23 +0000 Subject: [PATCH 66/72] Fix mismerge of agent show_history method --- src/primaite/game/agent/interface.py | 57 +++++----------------------- 1 file changed, 9 insertions(+), 48 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 1fef14ef..a6e9739f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -107,66 +107,27 @@ class AbstractAgent(BaseModel, ABC): self.reward_function = RewardFunction(config=self.config.reward_function) return super().model_post_init(__context) - def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: - """Update the given table with information from given AgentHistoryItem.""" - node, application = "unknown", "unknown" - if (node_id := item.parameters.get("node_id")) is not None: - node = self.action_manager.node_names[node_id] - if (application_id := item.parameters.get("application_id")) is not None: - application = self.action_manager.application_names[node_id][application_id] - if (application_name := item.parameters.get("application_name")) is not None: - application = application_name - table.add_row([item.timestep, item.action, node, application, item.response.status]) - return table - def show_history(self, ignored_actions: Optional[list] = None): """ - Print an agent action provided it's not the DONOTHING action. + Print an agent action provided it's not the do-nothing action. :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. - If not provided, defaults to ignore DONOTHING actions. + If not provided, defaults to ignore do-nothing actions. """ if not ignored_actions: - ignored_actions = ["DONOTHING"] + ignored_actions = ["do-nothing"] table = PrettyTable() - table.field_names = ["Step", "Action", "Node", "Application", "Response"] - print(f"Actions for '{self.agent_name}':") + table.field_names = ["Step", "Action", "Params", "Response", "Response Data"] + print(f"Actions for '{self.config.ref}':") for item in self.history: if item.action in ignored_actions: pass else: - table = self.add_agent_action(item=item, table=table) - print(table) + # format dict by putting each key-value entry on a separate line and putting a blank line on the end. + param_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.parameters.items()], ""]) + data_string = "\n".join([*[f"{k}: {v:.30}" for k, v in item.response.data], ""]) - def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable: - """Update the given table with information from given AgentHistoryItem.""" - node, application = "unknown", "unknown" - if (node_id := item.parameters.get("node_id")) is not None: - node = self.action_manager.node_names[node_id] - if (application_id := item.parameters.get("application_id")) is not None: - application = self.action_manager.application_names[node_id][application_id] - if (application_name := item.parameters.get("application_name")) is not None: - application = application_name - table.add_row([item.timestep, item.action, node, application, item.response.status]) - return table - - def show_history(self, ignored_actions: Optional[list] = None): - """ - Print an agent action provided it's not the DONOTHING action. - - :param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history. - If not provided, defaults to ignore DONOTHING actions. - """ - if not ignored_actions: - ignored_actions = ["DONOTHING"] - table = PrettyTable() - table.field_names = ["Step", "Action", "Node", "Application", "Response"] - print(f"Actions for '{self.agent_name}':") - for item in self.history: - if item.action in ignored_actions: - pass - else: - table = self.add_agent_action(item=item, table=table) + table.add_row([item.timestep, item.action, param_string, item.response.status, data_string]) print(table) def update_observation(self, state: Dict) -> ObsType: From 8c399c4f61052c7bd55e0d330ee0e3542e61a0d5 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 18:11:42 +0000 Subject: [PATCH 67/72] Fix mismerge of c2 e2e notebook --- ...ommand-and-Control-E2E-Demonstration.ipynb | 157 +++++------------- 1 file changed, 39 insertions(+), 118 deletions(-) diff --git a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb index 882c3429..f187c8d5 100644 --- a/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-and-Control-E2E-Demonstration.ipynb @@ -59,7 +59,7 @@ "custom_c2_agent = \"\"\"\n", " - ref: CustomC2Agent\n", " team: RED\n", - " type: ProxyAgent\n", + " type: proxy-agent\n", "\n", " action_space:\n", " action_map:\n", @@ -74,12 +74,8 @@ " 2:\n", " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.21\n", - " keep_alive_frequency:\n", - " masquerade_protocol:\n", - " masquerade_port:\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.21\n", " 3:\n", " action: node-application-execute\n", " options:\n", @@ -101,10 +97,9 @@ " 5:\n", " action: c2-server-ransomware-configure\n", " options:\n", - " node_id: 1\n", - " config:\n", - " server_ip_address: 192.168.1.14\n", - " payload: ENCRYPT\n", + " node_name: client_1\n", + " server_ip_address: 192.168.1.14\n", + " payload: ENCRYPT\n", " 6:\n", " action: c2-server-data-exfiltrate\n", " options:\n", @@ -123,25 +118,20 @@ " 8:\n", " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.21\n", - " keep_alive_frequency: 10\n", - " masquerade_protocol: TCP\n", - " masquerade_port: DNS\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.21\n", + " keep_alive_frequency: 10\n", + " masquerade_protocol: tcp\n", + " masquerade_port: dns\n", " 9:\n", " action: configure-c2-beacon\n", " options:\n", - " node_id: 0\n", - " config:\n", - " c2_server_ip_address: 192.168.10.22\n", - " keep_alive_frequency:\n", - " masquerade_protocol:\n", - " masquerade_port:\n", + " node_name: web_server\n", + " c2_server_ip_address: 192.168.10.22\n", "\n", " reward_function:\n", " reward_components:\n", - " - type: DUMMY\n", + " - type: dummy\n", "\"\"\"\n", "c2_agent_yaml = yaml.safe_load(custom_c2_agent)" ] @@ -287,13 +277,6 @@ "\n", "```yaml\n", " action_space:\n", - " options:\n", - " nodes: # Node List\n", - " - node_name: web_server\n", - " applications: \n", - " - application_name: C2Beacon\n", - " ...\n", - " ...\n", " action_map:\n", " 3:\n", " action: node-application-execute\n", @@ -352,13 +335,6 @@ "\n", "``` yaml\n", " action_space:\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 4:\n", " action: c2-server-terminal-command\n", @@ -408,13 +384,6 @@ "\n", "``` yaml\n", " action_space:\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 5:\n", " action: c2-server-ransomware-configure\n", @@ -459,13 +428,6 @@ "\n", "``` yaml\n", " action_space:\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 6:\n", " action: c2-server-data-exfiltrate\n", @@ -524,13 +486,6 @@ "\n", "``` yaml\n", " action_space:\n", - " options:\n", - " nodes: # Node List\n", - " ...\n", - " - node_name: client_1\n", - " applications: \n", - " - application_name: C2Server\n", - " ...\n", " action_map:\n", " 7:\n", " action: c2-server-ransomware-launch\n", @@ -584,8 +539,8 @@ " type: custom\n", " options:\n", " components:\n", - " - type: NODES\n", - " label: NODES\n", + " - type: nodes\n", + " label: nodes\n", " options:\n", " hosts:\n", " - hostname: web_server\n", @@ -667,55 +622,29 @@ " 1:\n", " action: node-application-remove\n", " options:\n", - " node_id: 0\n", - " application_name: C2Beacon\n", + " node_name: web_server\n", + " application_name: c2-beacon\n", " 2:\n", " action: node-shutdown\n", " options:\n", - " node_id: 0\n", + " node_name: web_server\n", " 3:\n", " action: router-acl-add-rule\n", " options:\n", " target_router: router_1\n", " position: 1\n", - " permission: 2\n", - " source_ip_id: 2\n", - " dest_ip_id: 3\n", - " source_port_id: 2\n", - " dest_port_id: 2\n", - " protocol_id: 1\n", - " source_wildcard_id: 0\n", - " dest_wildcard_id: 0\n", + " permission: DENY\n", + " src_ip: 192.168.10.21\n", + " dst_ip: 192.168.1.12\n", + " src_port: HTTP\n", + " dst_port: HTTP\n", + " protocol_name: ALL\n", + " src_wildcard: 0.0.0.1\n", + " dst_wildcard: 0.0.0.1\n", "\n", - "\n", - " options:\n", - " nodes:\n", - " - node_name: web_server\n", - " applications:\n", - " - application_name: C2Beacon\n", - "\n", - " - node_name: database_server\n", - " folders:\n", - " - folder_name: database\n", - " files:\n", - " - file_name: database.db\n", - " services:\n", - " - service_name: DatabaseService\n", - " - node_name: router_1\n", - "\n", - " max_folders_per_node: 2\n", - " max_files_per_folder: 2\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.10.21\n", - " - 192.168.1.12\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " reward_function:\n", " reward_components:\n", - " - type: DUMMY\n", + " - type: dummy\n", "\n", " agent_settings:\n", " flatten_obs: False\n", @@ -1112,7 +1041,7 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n", + "ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", @@ -1200,7 +1129,7 @@ "outputs": [], "source": [ "# Attempting to install the C2 RansomwareScript\n", - "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"ransomware-script\"],\n", + "ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n", " \"username\": \"admin\",\n", " \"password\": \"admin\"}\n", "\n", @@ -1325,7 +1254,7 @@ "metadata": {}, "outputs": [], "source": [ - "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database-server\")\n", + "database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database_server\")\n", "database_server.software_manager.file_system.show(full=True)" ] }, @@ -1369,12 +1298,14 @@ "source": [ "As demonstrated earlier, red agents can use the ``configure-c2-beacon`` action to configure these settings mid episode through the configuration options:\n", "\n", - "``` YAML\n", - "...\n", - " action: configure_c2_beacon\n", - " options:\n", - " node_id: 0\n", - " config:\n", + "```YAML\n", + "\n", + " action_space:\n", + " action_map:\n", + " 8:\n", + " action: configure-c2-beacon\n", + " options:\n", + " node_name: web_server\n", " c2_server_ip_address: 192.168.10.21\n", " keep_alive_frequency: 10\n", " masquerade_protocol: tcp\n", @@ -1739,16 +1670,6 @@ "\n", "display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "env.game.agents[\"CustomC2Agent\"].show_history()" - ] } ], "metadata": { From cf33dcdcf9e9948bcd9c22331e304fafff15088e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 18:12:20 +0000 Subject: [PATCH 68/72] remove outdated information from agents doc page --- docs/source/configuration/agents.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst index ee84aede..c2674e31 100644 --- a/docs/source/configuration/agents.rst +++ b/docs/source/configuration/agents.rst @@ -20,12 +20,6 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo - ref: green_agent_example team: GREEN type: probabilistic-agent - observation_space: - type: UC2GreenObservation # TODO: what - action_space: - reward_function: - reward_components: - - type: dummy agent_settings: start_settings: From f1a36cafaac887573f8b7d9be6d552d0383a4951 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 18:13:45 +0000 Subject: [PATCH 69/72] remove outdated information from data manipulation bot doc page --- .../applications/data_manipulation_bot.rst | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 3ddb8bca..04c581bd 100644 --- a/docs/source/simulation_components/system/applications/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -97,26 +97,6 @@ If not using the data manipulation bot manually, it needs to be used with a data team: RED type: red-database-corrupting-agent - observation_space: - type: uc2-red-observation #TODO what - options: - nodes: - - node_name: client_1 - observations: - - logon_status - - operating_status - applications: - - application_ref: data_manipulation_bot - observations: - operating_status - health_status - folders: {} - - action_space: - reward_function: - reward_components: - - type: dummy - agent_settings: start_settings: start_step: 25 From fd367d1f0eff88095612cfefcdf7a20f559b395d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 18:21:28 +0000 Subject: [PATCH 70/72] Fix typos and duplicate identifiers in docs --- docs/source/request_system.rst | 2 -- .../system/applications/database_client.rst | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index 30ced50a..93fc2a9f 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -4,8 +4,6 @@ .. _request_system: -.. _request_system: - Request System ************** diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst index 472b504c..465827d9 100644 --- a/docs/source/simulation_components/system/applications/database_client.rst +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -59,7 +59,7 @@ Python # install DatabaseClient client.software_manager.install(DatabaseClient) - database_client: DatabaseClient = client.software_manager.software.get("database-sclient") + database_client: DatabaseClient = client.software_manager.software.get("database-client") # Configure the DatabaseClient database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService From bab40603788f2f06fa6945527238217b79e54743 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 26 Feb 2025 18:26:54 +0000 Subject: [PATCH 71/72] Fix agent config in terminal processing notebook --- src/primaite/notebooks/Terminal-Processing.ipynb | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/primaite/notebooks/Terminal-Processing.ipynb b/src/primaite/notebooks/Terminal-Processing.ipynb index 07d38791..755b0184 100644 --- a/src/primaite/notebooks/Terminal-Processing.ipynb +++ b/src/primaite/notebooks/Terminal-Processing.ipynb @@ -298,21 +298,7 @@ " - ref: CustomC2Agent\n", " team: RED\n", " type: proxy-agent\n", - " observation_space: null\n", " action_space:\n", - " options:\n", - " nodes:\n", - " - node_name: client_1\n", - " max_folders_per_node: 1\n", - " max_files_per_folder: 1\n", - " max_services_per_node: 2\n", - " max_nics_per_node: 8\n", - " max_acl_rules: 10\n", - " ip_list:\n", - " - 192.168.1.21\n", - " - 192.168.1.14\n", - " wildcard_list:\n", - " - 0.0.0.1\n", " action_map:\n", " 0:\n", " action: do-nothing\n", @@ -508,7 +494,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, From 2b04695c2e8e8d0ca2af074c6eb45f41ee78d63f Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 27 Feb 2025 10:07:17 +0000 Subject: [PATCH 72/72] Apply suggestions from code review --- src/primaite/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 21a8fb4d..d9b058f1 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -4.0.0a1-dev +4.0.0-dev