Merge branch 'dev' into feature/2689-command-and-control

This commit is contained in:
Archer Bowen
2024-08-07 14:18:40 +01:00
58 changed files with 3158 additions and 534 deletions

View File

@@ -102,9 +102,7 @@ stages:
version: '2.1.x'
- script: |
coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml --cov-fail-under=80
coverage xml -o coverage.xml -i
coverage html -d htmlcov -i
python run_test_and_coverage.py
displayName: 'Run tests and code coverage'
# Run the notebooks

View File

@@ -7,255 +7,179 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency in the airspace.
- **Bandwidth Tracking**: Tracks data transmission across each frequency.
- **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files.
- **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file.
- Random Number Generator Seeding by specifying a random number seed in the config file.
- Implemented Terminal service class, providing a generic terminal simulation.
### Changed
- **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`.
- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes.
- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
### Fixed
- **Transmission Permission Logic**: Corrected the logic in `can_transmit_frame` to accurately prevent overloads by checking if the transmission of a frame stays within allowable bandwidth limits after considering current load.
- Removed the install/uninstall methods in the node class and made the software manager install/uninstall handle all of their functionality.
[//]: # (This file needs tidying up between 2.0.0 and this line as it hasn't been segmented into 3.0.0 and 3.1.0 and isn't compliant with https://keepachangelog.com/en/1.1.0/)
## 3.0.0b9
- Removed deprecated `PrimaiteSession` class.
- Added ability to set log levels via configuration.
- Upgraded pydantic to version 2.7.0
- Upgraded Ray to version >= 2.9
- Added ipywidgets to the dependencies
- Added ability to define scenarios that change depending on the episode number.
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient
- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's.
- Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService.
- Added the ability for a DatabaseService to terminate a connection.
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
- Added additional show functions to enable connection inspection.
- Updates to agent logging, to include the reward both per step and per episode.
- Introduced Developer CLI tools to assist with developing/debugging PrimAITE
- Can be enabled via `primaite dev-mode enable`
- Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located
- Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization.
- Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training.
## [Unreleased]
- Made requests fail to reach their target if the node is off
- Added responses to requests
- Made environment reset completely recreate the game object.
- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack.
- Changed the data manipulation scenario to include a second green agent on client 1.
- Refactored actions and observations to be configurable via object name, instead of UUID.
- Made database patch correctly take 2 timesteps instead of being immediate
- Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING
- Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function.
- Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config.
- Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config.
- Added support for SQL INSERT command.
- Added ability to log each agent's action choices in each step to a JSON file.
- Removal of Link bandwidth hardcoding. This can now be configured via the network configuraiton yaml. Will default to 100 if not present.
- Added NMAP application to all host and layer-3 network nodes.
### Bug Fixes
- ACL rules were not resetting on episode reset.
- ACLs were not showing up correctly in the observation space.
- Blue agent's ACL actions were being applied against the wrong IP addresses
- Deleted files and folders did not reset correctly on episode reset.
- Service health status was using the actual health state instead of the visible health state
- Database file health status was using the incorrect value for negative rewards
- Preventing file actions from reaching their intended file
- The data manipulation attack was triggered at episode start.
- FTP STOR stored an additional copy on the client machine's filesystem
- The red agent acted to early
- Order of service health state
- Starting a node didn't start the services on it
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
- The use of NODE_FILE_CHECKHASH and NODE_FOLDER_CHECKHASH in the current release is marked as 'Not Implemented'.
## [3.2.0] - 2024-07-18
### Added
- Network Hardware - Added base hardware module with NIC, SwitchPort, Node, and Link. Nodes have
fundamental services like ARP, ICMP, and PCAP running them by default.
- Network Transmission - Modelled OSI Model layers 1 through to 5 with various classes for creating network frames and
transmitting them from a Service/Application, down through the layers, over the wire, and back up through the layers to
a Service/Application another machine.
- Introduced `Router` and `Switch` classes to manage networking routes more effectively.
- Added `ACLRule` and `RouteTableEntry` classes as part of the `Router`.
- New `.show()` methods in all network component classes to inspect the state in either plain text or markdown formats.
- Added `Computer` and `Server` class to better differentiate types of network nodes.
- Integrated a new Use Case 2 network into the system.
- New unit tests to verify routing between different subnets using `.ping()`.
- system - Added the core structure of Application, Services, and Components. Also added a SoftwareManager and
SessionManager.
- Permission System - each action can define criteria that will be used to permit or deny agent actions.
- File System - ability to emulate a node's file system during a simulation
- Example notebooks - There are 5 jupyter notebook which walk through using PrimAITE
1. Training a Stable Baselines 3 agent
2. Training a single agent system using Ray RLLib
3. Training a multi-agent system Ray RLLib
4. Data manipulation end to end demonstration
5. Data manipulation scenario with customised red agents
- Database:
- `DatabaseClient` and `DatabaseService` created to allow emulation of database actions
- Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup
- Red Agent Services:
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding.
- `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance.
- DNS Services: `DNSClient` and `DNSServer`
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
- NTP Services: `NTPClient` and `NTPServer`
- **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic.
- **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required.
- **Enhanced Frame Reception**: The `receive_frame` method in `RouterNIC` is tailored to handle frames based on Layer 2 (Ethernet) checks, focusing on MAC address-based routing and broadcast frame acceptance.
- **Subnet-Wide Broadcasting for Services and Applications**: Implemented the ability for services and applications to conduct broadcasts across an entire IPv4 subnet within the network simulation framework.
- Introduced the `NetworkInterface` abstract class to provide a common interface for all network interfaces. Subclasses are divided into two main categories: `WiredNetworkInterface` and `WirelessNetworkInterface`, each serving as an abstract base class (ABC) for more specific interface types. Under `WiredNetworkInterface`, the subclasses `NIC` and `SwitchPort` were added. For wireless interfaces, `WirelessNIC` and `WirelessAccessPoint` are the subclasses under `WirelessNetworkInterface`.
- Added `Layer3Interface` as an abstract base class for networking functionalities at layer 3, including IP addressing and routing capabilities. This class is inherited by `NIC`, `WirelessNIC`, and `WirelessAccessPoint` to provide them with layer 3 capabilities, facilitating their role in both wired and wireless networking contexts with IP-based communication.
- Created the `ARP` and `ICMP` service classes to handle Address Resolution Protocol operations and Internet Control Message Protocol messages, respectively, with `RouterARP` and `RouterICMP` for router-specific implementations.
- Created `HostNode` as a subclass of `Node`, extending its functionality with host-specific services and applications. This class is designed to represent end-user devices like computers or servers that can initiate and respond to network communications.
- Introduced a new `IPV4Address` type in the Pydantic model for enhanced validation and auto-conversion of IPv4 addresses from strings using an `ipv4_validator`.
- Comprehensive documentation for the Node and its network interfaces, detailing the operational workflow from frame reception to application-level processing.
- Detailed descriptions of the Session Manager and Software Manager functionalities, including their roles in managing sessions, software services, and applications within the simulation.
- Documentation for the Packet Capture (PCAP) service and SysLog functionality, highlighting their importance in logging network frames and system events, respectively.
- Expanded documentation on network devices such as Routers, Switches, Computers, and Switch Nodes, explaining their specific processing logic and protocol support.
- **Firewall Node**: Introduced the `Firewall` class extending the functionality of the existing `Router` class. The `Firewall` class incorporates advanced features to scrutinize, direct, and filter traffic between various network zones, guided by predefined security rules and policies. Key functionalities include:
- Access Control Lists (ACLs) for traffic filtering based on IP addresses, protocols, and port numbers.
- Network zone segmentation for managing traffic across external, internal, and DMZ (De-Militarized Zone) networks.
- Interface configuration to establish connectivity and define network parameters for external, internal, and DMZ interfaces.
- Protocol and service management to oversee traffic and enforce security policies.
- Dynamic traffic processing and filtering to ensure network security and integrity.
- `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies.
- `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations.
- `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies.
- Documentation Updates:
- Examples include how to set up PrimAITE session via config
- Examples include how to create nodes and install software via config
- Examples include how to set up PrimAITE session via Python
- Examples include how to create nodes and install software via Python
- Added missing ``DoSBot`` documentation page
- Added diagrams where needed to make understanding some things easier
- Templated parts of the documentation to prevent unnecessary repetition and for easier maintaining of documentation
- Separated documentation pages of some items i.e. client and server software were on the same pages - which may make things confusing
- Configuration section at the bottom of the software pages specifying the configuration options available (and which ones are optional)
- Ability to add ``Firewall`` node via config
- Ability to add ``Router`` routes via config
- Ability to add ``Router``/``Firewall`` ``ACLRule`` via config
- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events.
- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE".
- Router-specific SessionManager Implementation: Introduced a specialized version of the SessionManager tailored for router operations. This enhancement enables the SessionManager to determine the routing path by consulting the route table.
- Action penalty is a reward component that applies a negative reward for doing any action other than DONOTHING
- Application configuration actions for RansomwareScript, DatabaseClient, and DoSBot applications
- Ability to configure how long it takes to apply the service fix action
- Terminal service using SSH
- Airspaces now track the amount of data being transmitted, viewable using the `show_bandwidth_load` method
- Tests to verify that airspace bandwidth is applied correctly and can be configured via YAML
- Agent logging for agents' internal decision logic
- Action masking in all PrimAITE environments
### Changed
- Integrated the RouteTable into the Routers frame processing.
- Frames are now dropped when their TTL reaches 0
- **NIC Functionality Update**: Updated the Network Interface Card (`NIC`) functionality to support Layer 3 (L3) broadcasts.
- **Layer 3 Broadcast Handling**: Enhanced the existing `NIC` classes to correctly process and handle Layer 3 broadcasts. This update allows devices using standard NICs to effectively participate in network activities that involve L3 broadcasting.
- **Improved Frame Reception Logic**: The `receive_frame` method of the `NIC` class has been updated to include additional checks and handling for L3 broadcasts, ensuring proper frame processing in a wider range of network scenarios.
- Standardised the way network interfaces are accessed across all `Node` subclasses (`HostNode`, `Router`, `Switch`) by maintaining a comprehensive `network_interface` attribute. This attribute captures all network interfaces by their port number, streamlining the management and interaction with network interfaces across different types of nodes.
- Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework.
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
- Integration of NMNE capturing functionality within the `NICObservation` class.
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`
- Removed legacy training modules
- Removed tests for legacy code
- Application registry was moved to the `Application` class and now updates automatically when Application is subclassed
- Databases can no longer respond to request while performing a backup
- Application install no longer accepts an `ip_address` parameter
- Application install action can now be used on all applications
- Actions have additional logic for checking validity
- Frame `size` attribute now includes both core size and payload size in bytes
- The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`
- Tidied up CHANGELOG
- Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
- Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
### Fixed
- Addressed network transmission issues that previously allowed ARP requests to be incorrectly routed and repeated across different subnets. This fix ensures ARP requests are correctly managed and confined to their appropriate network segments.
- Resolved problems in `Node` and its subclasses where the default gateway configuration was not properly utilized for communications across different subnets. This correction ensures that nodes effectively use their configured default gateways for outbound communications to other network segments, thereby enhancing the network's routing functionality and reliability.
- Network Interface Port name/num being set properly for sys log and PCAP output.
- Links and airspaces can no longer transmit data if this would exceed their bandwidth
## [3.1.0] - 2024-06-25
### Added
- Observations for traffic amounts on host network interfaces
- NMAP application network discovery, including ping scan and port scan
- NMAP actions
- Automated adding copyright notices to source files
- More file types
- `show` method to files
- `model_dump` methods to network enums to enable better logging
### Changed
- Updated file system actions to stop failures when creating duplicate files
- Improved parsing of ACL add rule actions to make some parameters optional
### Fixed
- Fixed database client uninstall failing due to persistent connections
- Fixed packet storm when pinging broadcast addresses
## [3.0.0] - 2024-06-10
### Added
- New simulation module
- Multi agent reinforcement learning support
- File system class to manage files and folders
- Software for nodes that can have its own behaviour
- Software classes to model FTP, Postgres databases, web traffic, NTP
- Much more detailed network simulation including packets, links, and network interfaces
- More node types: host, computer, server, router, switch, wireless router, and firewalls
- Network Hardware - NIC, SwitchPort, Node, and Link. Nodes have fundamental services like ARP, ICMP, and PCAP running them by default.
- Malicious network event detection
- New `game` module for managing agents
- ACL rule wildcard masking
- Network broadcasting
- Wireless transmission
- More detailed documentation
- Example jupyter notebooks to demonstrate new functionality
- More reward components
- Packet capture logs
- Node system logs
- Per-step full simulation state log
- Attack randomisation with respect to timing and attack source
- Ability to set log level via CLI
- Ability to vary the YAML configuration per-episode
- Developer CLI tools for enhanced debugging (with `primaite dev-mode enable`)
- `show` function to many simulation objects to inspect their current state
### Changed
- Decoupled the environment from the simulation by adding the `game` interface layer
- Made agents share a common base class
- Added more actions
- Made all agents use CAOS actions, including red and green agents
- Reworked YAML configuration file schema
- Reworked the reward system to be component-based
- Changed agent logs to create a JSON output instead of CSV with more detailed action information
- Made observation space flattening optional
- Made all logging optional
- Agent actions now provide responses with a success code
### Removed
- Legacy simulation modules
- Legacy training modules
- Tests for legacy code
- Hardcoded IERs and PoL, traffic generation is now handled by agents and software
- Inbuilt agent training scripts
## [2.0.0] - 2023-07-26
### Added
- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE.
- Application Directories to enable PrimAITE as a Python package with predefined directories for storage.
- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib.
- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`.
- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options.
- Session loading to revisit previously run sessions for SB3 Agents.
- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface.
- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs:
1. Session Metadata
2. Results
3. Diagrams
4. Saved agents (training checkpoints and a final trained agent).
- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup.
- Benchmarking of PrimAITE performance, showcasing session and step durations for reference.
- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`.
- Command Line Interface (CLI) for easy access and streamlined usage of PrimAITE.
- Application Directories to enable PrimAITE as a Python package with predefined directories for storage.
- Support for Ray Rllib, allowing training of PPO and A2C agents using Stable Baselines3 and Ray RLlib.
- Random Red Agent to train the blue agent against, with options for randomised Red Agent `POL` and `IER`.
- Repeatability of sessions through seed settings, and deterministic or stochastic evaluation options.
- Session loading to revisit previously run sessions for SB3 Agents.
- Agent Session Classes (`AgentSessionABC` and `HardCodedAgentSessionABC`) to standardise agent training with a common interface.
- Standardised Session Output in a structured format in the user's app sessions directory, providing four types of outputs: Session Metadata, Results, Diagrams, Trained agents.
- Configurable Observation Space managed by the `ObservationHandler` class for a more flexible observation space setup.
- Benchmarking of PrimAITE performance, showcasing session and step durations for reference.
- Documentation overhaul, including automatic API and test documentation with recursive Sphinx auto-summary, using the Furo theme for responsive light/dark theme, and enhanced navigation with `sphinx-code-tabs` and `sphinx-copybutton`.
### Changed
- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions.
- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`.
- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names.
- Docs and Tests now sit outside the `src/` directory.
- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages.
- All dependencies are now defined in the `pyproject.toml` file.
- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each.
- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management.
- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations.
- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority.
- Action Space updated to discrete spaces, introducing a new `ANY` action space option for combined `NODE` and `ACL` actions.
- Improved `Node` attribute naming convention for consistency, now adhering to `Pascal Case`.
- Package Structure has been refactored for better build, distribution, and installation, with all source code now in the `src/` directory, and the `PRIMAITE` Python package renamed to `primaite` to adhere to PEP-8 Package & Module Names.
- Docs and Tests now sit outside the `src/` directory.
- Non-python files (example config files, Jupyter notebooks, etc.) now sit inside a `*/_package_data/` directory in their respective sub-packages.
- All dependencies are now defined in the `pyproject.toml` file.
- Introduced individual configuration for the number of episodes and time steps for training and evaluation sessions, with separate config values for each.
- Decoupled the lay down config file from the training config, allowing more flexibility in configuration management.
- Updated `Transactions` to only report pre-action observation, improving the CSV header and providing more human-readable descriptions for columns relating to observations.
- Changes to `AccessControlList`, where the `acl` dictionary is now a list to accommodate changes to ACL action space and positioning of `ACLRules` inside the list to signal their level of priority.
### Fixed
- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation.
- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes.
- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands.
- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making.
- Various bug fixes, including Green IERs separation, correct clearing of links in the reference environment, and proper reward calculation.
- Logic to check if a node is OFF before executing actions on the node by the blue agent, preventing erroneous state changes.
- Improved functionality of Resetting a Node, adding "SHUTTING DOWN" and "BOOTING" operating states for more reliable reset commands.
- Corrected the order of actions in the `Primaite` env to ensure the blue agent uses the current state for decision-making.
## [1.1.1] - 2023-06-27
### Bug Fixes
* Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by:
* Implementing a reference copy of all green IERs (`self.green_iers_reference`).
* Clearing the traffic on reference IERs at the same time as the live IERs.
* Passing the `green_iers_reference` to the `apply_iers` function at the reference stage.
* Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`.
* Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment.
* Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes.
* Removing the unnecessary "Reapply PoL and IERs" action from the step function.
* Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function.
### Fixed
- Fixed bug whereby 'reference' environment links reach bandwidth capacity and are never cleared due to green & red IERs being applied to them. This bug had a knock-on effect that meant IERs were being blocked based on the full capacity of links on the reference environment which was not correct; they should only be based on the link capacity of the 'live' environment. This fix has been addressed by:
- Implementing a reference copy of all green IERs (`self.green_iers_reference`).
- Clearing the traffic on reference IERs at the same time as the live IERs.
- Passing the `green_iers_reference` to the `apply_iers` function at the reference stage.
- Passing the `green_iers_reference` as an additional argument to `calculate_reward_function`.
- Updating the green IERs section of the `calculate_reward_function` to now take into account both the green reference IERs and live IERs. The `green_ier_blocked` reward is only applied if the IER is blocked in the live environment but is running in the reference environment.
- Re-ordering the actions taken as part of the step function to ensure the blue action happens first before other changes.
- Removing the unnecessary "Reapply PoL and IERs" action from the step function.
- Moving the deep-copy of nodes and links to below the "Implement blue action" stage of the step function.
## [1.1.0] - 2023-03-13
### Added
* The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function.
* The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name.
* Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components:
* Blue agent observation space;
* Blue agent action space;
* Reward function;
* Node pattern-of-life.
* The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network).
* New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile.
* NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability.
* The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value.
* The User Guide has been updated to reflect all the above changes.
- The user can now initiate either a TRAINING session or an EVALUATION (test) session with the Stable Baselines 3 (SB3) agents via the config_main.yaml file. During evaluation/testing, the agent policy will be fixed (no longer learning) and subjected to the SB3 `evaluate_policy()` function.
- The user can choose whether a saved agent is loaded into the session (with reference to a URL) via the `config_main.yaml` file. They specify a Boolean true/false indicating whether a saved agent should be loaded, and specify the URL and file name.
- Active and Service nodes now possess a new "File System State" attribute. This attribute is permitted to have the states GOOD, CORRUPT, DESTROYED, REPAIRING, and RESTORING. This new feature affects the following components:
- Blue agent observation space;
- Blue agent action space;
- Reward function;
- Node pattern-of-life.
- The Red Agent node pattern-of-life has been enhanced so that node PoL is triggered by an 'initiator'. The initiator is either DIRECT (state change is applied to the node without any conditions), IER (state change is applied to the node based on IER entry condition), or SERVICE (state change is applied to the node based on a service state condition on the same node or a different node within the network).
- New default config named "config_5_DATA_MANIPULATION.yaml" and associated Training Use Case Profile.
- NodeStateInstruction has been split into `NodeStateInstructionGreen` and `NodeStateInstructionRed` to reflect the changes within the red agent pattern-of-life capability.
- The reward function has been enhanced so that node attribute states of resetting, patching, repairing, and restarting contribute to the overall reward value.
- The User Guide has been updated to reflect all the above changes.
### Changed
* "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing.
* "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
* "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file.
* Updates to Transactions.
- "config_1_DDOS_BASIC.yaml" modified to make it more simplistic to aid evaluation testing.
- "config_2_DDOS_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
- "config_3_DOS_VERY_BASIC.yaml" updated to reflect the addition of the File System State and the Red Agent node pattern-of-life enhancement.
- "config_UNIT_TEST.yaml" is a copy of the new "config_5_DATA_MANIPULATION.yaml" file.
- Updates to Transactions.
### Fixed
* Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default.
[unreleased]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/compare/v2.0.0...HEAD
[2.0.0]: https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/tag/v2.0.0
- Fixed "config_2_DDOS_BASIC.yaml" by adding another ACL rule to allow traffic to flow from Node 9 to Node 3. Previously, there was no rule, so one of the green IERs could not flow by default.

View File

@@ -13,9 +13,6 @@
* [Fork the repository](https://github.com/{todo:fill in URL}/PrimAITE/fork).
* Install the pre-commit hook with `pre-commit install`.
* Implement the bug fix.
* Update documentation where applicable.
* Update the **UNRELEASED** section of the [CHANGELOG.md](CHANGELOG.md) file
* Write a suitable test/tests.
* Commit the bug fix to the dev branch on your fork. If the bug has an open issue under [Issues](https://github.com/{todo:fill in URL}/PrimAITE/issues), reference the issue in the commit message (e.g. #1 references issue 1).
* Submit a pull request from your dev branch to the {todo:fill in URL}/PrimAITE dev branch. Again, if the bug has an open issue under [Issues](https://github.com/{todo:fill in URL}/PrimAITE/issues), reference the issue in the pull request description.

3
_config.yml Normal file
View File

@@ -0,0 +1,3 @@
# Used by nbmake to change build pipeline notebook timeout
execute:
timeout: 600

View File

@@ -53,3 +53,30 @@ The number of time steps required to occur in order for the node to cycle from `
Optional. Default value is ``3``.
The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``.
``users``
---------
The list of pre-existing users that are additional to the default admin user (``username=admin``, ``password=admin``).
Additional users are configured as an array nd must contain a ``username``, ``password``, and can contain an optional
boolean ``is_admin``.
Example of adding two additional users to a node:
.. code-block:: yaml
simulation:
network:
nodes:
- hostname: client_1
type: computer
ip_address: 192.168.10.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
users:
- username: jane.doe
password: '1234'
is_admin: true
- username: john.doe
password: password_1
is_admin: false

View File

@@ -97,8 +97,8 @@ Node Behaviours/Functions
- **receive_frame()**: Handles the processing of incoming network frames.
- **apply_timestep()**: Advances the state of the node according to the simulation timestep.
- **power_on()**: Initiates the node, enabling all connected Network Interfaces and starting all Services and
Applications, taking into account the `start_up_duration`.
- **power_off()**: Stops the node's operations, adhering to the `shut_down_duration`.
Applications, taking into account the ``start_up_duration``.
- **power_off()**: Stops the node's operations, adhering to the ``shut_down_duration``.
- **ping()**: Sends ICMP echo requests to a specified IP address to test connectivity.
- **has_enabled_network_interface()**: Checks if the node has any network interfaces enabled, facilitating network
communication.
@@ -109,3 +109,205 @@ Node Behaviours/Functions
The Node class handles installation of system software, network connectivity, frame processing, system logging, and
power states. It establishes baseline functionality while allowing subclassing to model specific node types like hosts,
routers, firewalls etc. The flexible architecture enables composing complex network topologies.
User, UserManager, and UserSessionManager
=========================================
The ``base.py`` module also includes essential classes for managing users and their sessions within the PrimAITE
simulation. These are the ``User``, ``UserManager``, and ``UserSessionManager`` classes. The base ``Node`` class comes
with ``UserManager``, and ``UserSessionManager`` classes pre-installed.
User Class
----------
The ``User`` class represents a user in the system. It includes attributes such as ``username``, ``password``,
``disabled``, and ``is_admin`` to define the user's credentials and status.
Example Usage
^^^^^^^^^^^^^
Creating a user:
.. code-block:: python
user = User(username="john_doe", password="12345")
UserManager Class
-----------------
The ``UserManager`` class handles user management tasks such as creating users, authenticating them, changing passwords,
and enabling or disabling user accounts. It maintains a dictionary of users and provides methods to manage them
effectively.
Example Usage
^^^^^^^^^^^^^
Creating a ``UserManager`` instance and adding a user:
.. code-block:: python
user_manager = UserManager()
user_manager.add_user(username="john_doe", password="12345")
Authenticating a user:
.. code-block:: python
user = user_manager.authenticate_user(username="john_doe", password="12345")
UserSessionManager Class
------------------------
The ``UserSessionManager`` class manages user sessions, including local and remote sessions. It handles session creation,
timeouts, and provides methods for logging users in and out.
Example Usage
^^^^^^^^^^^^^
Creating a ``UserSessionManager`` instance and logging a user in locally:
.. code-block:: python
session_manager = UserSessionManager()
session_id = session_manager.local_login(username="john_doe", password="12345")
Logging a user out:
.. code-block:: python
session_manager.local_logout()
Practical Examples
------------------
Below are unit tests which act as practical examples illustrating how to use the ``User``, ``UserManager``, and
``UserSessionManager`` classes within the context of a client-server network simulation.
Setting up a Client-Server Network
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
from typing import Tuple
from uuid import uuid4
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@pytest.fixture(scope="function")
def client_server_network() -> Tuple[Computer, Server, Network]:
network = Network()
client = Computer(
hostname="client",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client.power_on()
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
network.connect(client.network_interface[1], server.network_interface[1])
return client, server, network
Local Login Success
^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
Local Login Failure
^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_local_login_failure(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert not client.user_session_manager.local_user_logged_in
Adding a New User and Successful Local Login
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_user_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_manager.add_user(username="jane.doe", password="12345")
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
Clearing Previous Login on New Local Login
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_local_login_clears_previous_login(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
client.user_manager.add_user(username="jane.doe", password="12345")
new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "jane.doe"
assert new_session_id != current_session_id
Persistent Login for the Same User
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
def test_new_local_login_attempt_same_uses_persists(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
new_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
assert new_session_id == current_session_id

View File

@@ -49,3 +49,5 @@ fundamental network operations:
5. **NTP (Network Time Protocol) Client:** Synchronises the host's clock with network time servers.
6. **Web Browser:** A simulated application that allows the host to request and display web content.
7. **Terminal:** A simulated service that allows the host to connect to remote hosts and execute commands.

View File

@@ -0,0 +1,173 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _Terminal:
Terminal
========
The ``Terminal.py`` class provides a generic terminal simulation, by extending the base Service class within PrimAITE. The aim of this is to act as the primary entrypoint for Nodes within the environment.
Overview
--------
The Terminal service uses Secure Socket (SSH) as the communication method between terminals. They operate on port 22, and are part of the services automatically
installed on Nodes when they are instantiated.
Key capabilities
================
- Ensures packets are matched to an existing session
- Simulates common Terminal processes/commands.
- Leverages the Service base class for install/uninstall, status tracking etc.
Usage
=====
- Pre-Installs on any `Node` (component with the exception of `Switches`).
- Terminal Clients connect, execute commands and disconnect from remote nodes.
- Ensures that users are logged in to the component before executing any commands.
- Service runs on SSH port 22 by default.
Implementation
==============
- Manages remote connections in a dictionary by session ID.
- Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate.
- Extends Service class.
- A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
Usage
=====
The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node.
Python
""""""
.. code-block:: python
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
client = Computer(
hostname="client",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
terminal: Terminal = client.software_manager.software.get("Terminal")
Creating Remote Terminal Connection
"""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
Executing a basic application install command
"""""""""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"])
Creating a folder on a remote node
""""""""""""""""""""""""""""""""
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.execute(["file_system", "create", "folder", "downloads"])
Disconnect from Remote Node
.. code-block:: python
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
term_a_term_b_remote_connection.disconnect()

22
run_test_and_coverage.py Normal file
View File

@@ -0,0 +1,22 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import subprocess
import sys
from typing import Any
def run_command(command: Any):
"""Runs a command and returns the exit code."""
result = subprocess.run(command, shell=True)
if result.returncode != 0:
sys.exit(result.returncode)
# Run pytest with coverage
run_command(
"coverage run -m --source=primaite pytest -v -o junit_family=xunit2 "
"--junitxml=junit/test-results.xml --cov-fail-under=80"
)
# Generate coverage reports if tests passed
run_command("coverage xml -o coverage.xml -i")
run_command("coverage html -d htmlcov -i")

View File

@@ -129,6 +129,10 @@ agents:
simulation:
network:
nmne_config:
capture_nmne: true
nmne_capture_keywords:
- DELETE
nodes:
- hostname: client
type: computer

View File

@@ -294,7 +294,7 @@ class ConfigureDoSBotAction(AbstractAction):
"""Action which sets config parameters for a DoS bot on a node."""
class _Opts(BaseModel):
"""Schema for options that can be passed to this option."""
"""Schema for options that can be passed to this action."""
model_config = ConfigDict(extra="forbid")
target_ip_address: Optional[str] = None

View File

@@ -22,8 +22,6 @@ class ProbabilisticAgent(AbstractScriptedAgent):
"""Strict validation."""
action_probabilities: Dict[int, float]
"""Probability to perform each action in the action map. The sum of probabilities should sum to 1."""
random_seed: Optional[int] = None
"""Random seed. If set, each episode the agent will choose the same random sequence of actions."""
# TODO: give the option to still set a random seed, but have it vary each episode in a predictable way
# for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed.
@@ -59,17 +57,18 @@ class ProbabilisticAgent(AbstractScriptedAgent):
num_actions = len(action_space.action_map)
settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}}
# If seed not specified, set it to None so that numpy chooses a random one.
settings.setdefault("random_seed")
# The random number seed for np.random is dependent on whether a random number seed is set
# in the config file. If there is one it is processed by set_random_seed() in environment.py
# and as a consequence the the sequence of rng_seed's used here will be repeatable.
self.settings = ProbabilisticAgent.Settings(**settings)
self.rng = np.random.default_rng(self.settings.random_seed)
rng_seed = np.random.randint(0, 65535)
self.rng = np.random.default_rng(rng_seed)
# convert probabilities from
self.probabilities = np.asarray(list(self.settings.action_probabilities.values()))
super().__init__(agent_name, action_space, observation_space, reward_function)
self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}")
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -18,7 +18,7 @@ from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState, UserManager
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -26,7 +26,7 @@ from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
@@ -46,6 +46,7 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -59,6 +60,7 @@ SERVICE_TYPES_MAPPING = {
"FTPServer": FTPServer,
"NTPClient": NTPClient,
"NTPServer": NTPServer,
"Terminal": Terminal,
}
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
@@ -72,6 +74,8 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
seed: int = None
"""Random number seed for RNGs."""
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]
@@ -266,9 +270,12 @@ class PrimaiteGame:
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])
# Set the NMNE capture config
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
for node_cfg in nodes_cfg:
n_type = node_cfg["type"]
new_node = None
if n_type == "computer":
new_node = Computer(
hostname=node_cfg["hostname"],
@@ -318,6 +325,11 @@ class PrimaiteGame:
msg = f"invalid node type {n_type} in config"
_LOGGER.error(msg)
raise ValueError(msg)
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
@@ -535,10 +547,7 @@ class PrimaiteGame:
# Validate that if any agents are sharing rewards, they aren't forming an infinite loop.
game.setup_reward_sharing()
# Set the NMNE capture config
set_nmne_config(network_config.get("nmne_config", {}))
game.update_agents(game.get_sim_state())
return game
def setup_reward_sharing(self):

View File

@@ -101,7 +101,6 @@
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
@@ -135,8 +134,7 @@
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
"results = algo.train()"
]
},
{
@@ -191,8 +189,7 @@
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
"results = algo.train()"
]
}
],
@@ -212,7 +209,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Terminal Processing\n",
"\n",
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook serves as a guide on the functionality and use of the new Terminal simulation component.\n",
"\n",
"The Terminal service comes pre-installed on most Nodes (The exception being Switches, as these are currently dumb). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.simulator.system.services.terminal.terminal import Terminal\n",
"from primaite.simulator.network.container import Network\n",
"from primaite.simulator.network.hardware.nodes.host.computer import Computer\n",
"from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript\n",
"from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection\n",
"\n",
"def basic_network() -> Network:\n",
" \"\"\"Utility function for creating a default network to demonstrate Terminal functionality\"\"\"\n",
" network = Network()\n",
" node_a = Computer(hostname=\"node_a\", ip_address=\"192.168.0.10\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_a.power_on()\n",
" node_b = Computer(hostname=\"node_b\", ip_address=\"192.168.0.11\", subnet_mask=\"255.255.255.0\", start_up_duration=0)\n",
" node_b.power_on()\n",
" network.connect(node_a.network_interface[1], node_b.network_interface[1])\n",
" return network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The terminal can be accessed from a `Node` via the `software_manager` as demonstrated below. \n",
"\n",
"In the example, we have a basic network consisting of two computers, connected to form a basic network."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"network: Network = basic_network()\n",
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to send commands from `node_a` to `node_b`, you will need to `login` to `node_b` first, using valid user credentials. In the example below, we are remotely logging in to the 'admin' account on `node_b`, from `node_a`. \n",
"If you are not logged in, any commands sent will be rejected by the remote.\n",
"\n",
"Remote Logins return a RemoteTerminalConnection object, which can be used for sending commands to the remote node. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Login to the remote (node_b) from local (node_a)\n",
"term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username=\"admin\", password=\"Admin123!\", ip_address=\"192.168.0.11\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can view all active connections to a terminal through use of the `show()` method"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminal_b.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The new connection object allows us to forward commands to be executed on the target node. The example below demonstrates how you can remotely install an application on the target node."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.software_manager.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code block below demonstrates how the Terminal class allows the user of `terminal_a`, on `computer_a`, to send a command to `computer_b` to create a downloads folder. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display the current state of the file system on computer_b\n",
"computer_b.file_system.show()\n",
"\n",
"# Send command\n",
"term_a_term_b_remote_connection.execute([\"file_system\", \"create\", \"folder\", \"downloads\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resultant call to `computer_b.file_system.show()` shows that the new folder has been created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"computer_b.file_system.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When finished, the connection can be closed by calling the `disconnect` function of the Remote Client object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display active connection\n",
"terminal_a.show()\n",
"terminal_b.show()\n",
"\n",
"term_a_term_b_remote_connection.disconnect()\n",
"\n",
"terminal_a.show()\n",
"\n",
"terminal_b.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -24,14 +24,11 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
@@ -72,7 +69,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training\n",
"#### Start the training\n",
"This example will save outputs to a default Ray directory and use mostly default settings."
]
},
@@ -82,13 +79,8 @@
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128},\n",
" ),\n",
" param_space=config\n",
").fit()"
"algo = config.build()\n",
"results = algo.train()"
]
}
],

View File

@@ -17,12 +17,10 @@
"metadata": {},
"outputs": [],
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"\n",
@@ -64,7 +62,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training"
"#### Start the training"
]
},
{
@@ -73,13 +71,8 @@
"metadata": {},
"outputs": [],
"source": [
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
"algo = config.build()\n",
"results = algo.train()\n"
]
}
],

View File

@@ -1,5 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
import random
import sys
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -17,6 +19,36 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
# Check torch is installed
try:
import torch as th
except ModuleNotFoundError:
_LOGGER.debug("Torch not available for importing")
def set_random_seed(seed: int) -> Union[None, int]:
"""
Set random number generators.
:param seed: int
"""
if seed is None or seed == -1:
return None
elif seed < -1:
raise ValueError("Invalid random number seed")
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# Seed the RNG for all devices (both CPU and CUDA)
# if torch not installed don't set random seed.
if sys.modules["torch"]:
th.manual_seed(seed)
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
return seed
class PrimaiteGymEnv(gymnasium.Env):
"""
@@ -31,6 +63,9 @@ class PrimaiteGymEnv(gymnasium.Env):
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
"""Get RNG seed from config file. NB: Must be before game instantiation."""
self.seed = set_random_seed(self.seed)
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
@@ -42,6 +77,8 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
@@ -108,6 +145,8 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if seed is not None:
set_random_seed(seed)
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:

View File

@@ -63,6 +63,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
@@ -176,6 +177,7 @@ class PrimaiteRayEnv(gymnasium.Env):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
super().reset() # Ensure PRNG seed is set everywhere
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}

View File

@@ -6,12 +6,11 @@ import secrets
from abc import ABC, abstractmethod
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar, Union
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validate_call
import primaite.simulator.network.nmne
from primaite import getLogger
from primaite.exceptions import NetworkError
from primaite.interface.request import RequestResponse
@@ -20,17 +19,10 @@ from primaite.simulator.core import RequestFormat, RequestManager, RequestPermis
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.nmne import (
CAPTURE_BY_DIRECTION,
CAPTURE_BY_IP_ADDRESS,
CAPTURE_BY_KEYWORD,
CAPTURE_BY_PORT,
CAPTURE_BY_PROTOCOL,
CAPTURE_NMNE,
NMNE_CAPTURE_KEYWORDS,
)
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.packet_capture import PacketCapture
from primaite.simulator.system.core.session_manager import SessionManager
@@ -38,7 +30,8 @@ from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.processes.process import Process
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.software import IOSoftware
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.simulator.system.software import IOSoftware, Software
from primaite.utils.converters import convert_dict_enum_keys_to_enum_values
from primaite.utils.validators import IPV4Address
@@ -108,8 +101,11 @@ class NetworkInterface(SimComponent, ABC):
pcap: Optional[PacketCapture] = None
"A PacketCapture instance for capturing and analysing packets passing through this interface."
nmne_config: ClassVar[NMNEConfig] = NMNEConfig()
"A dataclass defining malicious network events to be captured."
nmne: Dict = Field(default_factory=lambda: {})
"A dict containing details of the number of malicious network events captured."
"A dict containing details of the number of malicious events captured."
traffic: Dict = Field(default_factory=lambda: {})
"A dict containing details of the inbound and outbound traffic by port and protocol."
@@ -167,8 +163,8 @@ class NetworkInterface(SimComponent, ABC):
"enabled": self.enabled,
}
)
if CAPTURE_NMNE:
state.update({"nmne": {k: v for k, v in self.nmne.items()}})
if self.nmne_config and self.nmne_config.capture_nmne:
state.update({"nmne": self.nmne})
state.update({"traffic": convert_dict_enum_keys_to_enum_values(self.traffic)})
return state
@@ -201,7 +197,7 @@ class NetworkInterface(SimComponent, ABC):
:param inbound: Boolean indicating if the frame direction is inbound. Defaults to True.
"""
# Exit function if NMNE capturing is disabled
if not CAPTURE_NMNE:
if not (self.nmne_config and self.nmne_config.capture_nmne):
return
# Initialise basic frame data variables
@@ -222,27 +218,27 @@ class NetworkInterface(SimComponent, ABC):
frame_str = str(frame.payload)
# Proceed only if any NMNE keyword is present in the frame payload
if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS):
if any(keyword in frame_str for keyword in self.nmne_config.nmne_capture_keywords):
# Start with the root of the NMNE capture structure
current_level = self.nmne
# Update NMNE structure based on enabled settings
if CAPTURE_BY_DIRECTION:
if self.nmne_config.capture_by_direction:
# Set or get the dictionary for the current direction
current_level = current_level.setdefault("direction", {})
current_level = current_level.setdefault(direction, {})
if CAPTURE_BY_IP_ADDRESS:
if self.nmne_config.capture_by_ip_address:
# Set or get the dictionary for the current IP address
current_level = current_level.setdefault("ip_address", {})
current_level = current_level.setdefault(ip_address, {})
if CAPTURE_BY_PROTOCOL:
if self.nmne_config.capture_by_protocol:
# Set or get the dictionary for the current protocol
current_level = current_level.setdefault("protocol", {})
current_level = current_level.setdefault(protocol, {})
if CAPTURE_BY_PORT:
if self.nmne_config.capture_by_port:
# Set or get the dictionary for the current port
current_level = current_level.setdefault("port", {})
current_level = current_level.setdefault(port, {})
@@ -251,8 +247,8 @@ class NetworkInterface(SimComponent, ABC):
keyword_level = current_level.setdefault("keywords", {})
# Increment the count for detected keywords in the payload
if CAPTURE_BY_KEYWORD:
for keyword in NMNE_CAPTURE_KEYWORDS:
if self.nmne_config.capture_by_keyword:
for keyword in self.nmne_config.nmne_capture_keywords:
if keyword in frame_str:
# Update the count for each keyword found
keyword_level[keyword] = keyword_level.get(keyword, 0) + 1
@@ -794,6 +790,650 @@ class Link(SimComponent):
self.current_load = 0.0
class User(SimComponent):
"""
Represents a user in the PrimAITE system.
:ivar username: The username of the user
:ivar password: The password of the user
:ivar disabled: Boolean flag indicating whether the user is disabled
:ivar is_admin: Boolean flag indicating whether the user has admin privileges
"""
username: str
"""The username of the user"""
password: str
"""The password of the user"""
disabled: bool = False
"""Boolean flag indicating whether the user is disabled"""
is_admin: bool = False
"""Boolean flag indicating whether the user has admin privileges"""
num_of_logins: int = 0
"""Counts the number of the User has logged in"""
def describe_state(self) -> Dict:
"""
Returns a dictionary representing the current state of the user.
:return: A dict containing the state of the user
"""
return self.model_dump()
class UserManager(Service):
"""
Manages users within the PrimAITE system, handling creation, authentication, and administration.
:param users: A dictionary of all users by their usernames
:param admins: A dictionary of admin users by their usernames
:param disabled_admins: A dictionary of currently disabled admin users by their usernames
"""
users: Dict[str, User] = {}
def __init__(self, **kwargs):
"""
Initializes a UserManager instanc.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
rm.add_request(
"change_password",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.change_user_password(username=request[0], current_password=request[1], new_password=request[2])
)
),
)
return rm
def describe_state(self) -> Dict:
"""
Returns the state of the UserManager along with the number of users and admins.
:return: A dict containing detailed state information
"""
state = super().describe_state()
state.update({"total_users": len(self.users), "total_admins": len(self.admins) + len(self.disabled_admins)})
state["users"] = {k: v.describe_state() for k, v in self.users.items()}
return state
def show(self, markdown: bool = False):
"""
Display the Users.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
table = PrettyTable(["Username", "Admin", "Disabled"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.sys_log.hostname} User Manager"
for user in self.users.values():
table.add_row([user.username, user.is_admin, user.disabled])
print(table.get_string(sortby="Username"))
@property
def non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and not v.disabled}
@property
def disabled_non_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled non-admin users.
:return: A dictionary with usernames as keys and User objects as values for non-admin, disabled users.
"""
return {k: v for k, v in self.users.items() if not v.is_admin and v.disabled}
@property
def admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all enabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, non-disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and not v.disabled}
@property
def disabled_admins(self) -> Dict[str, User]:
"""
Returns a dictionary of all disabled admin users.
:return: A dictionary with usernames as keys and User objects as values for admin, disabled users.
"""
return {k: v for k, v in self.users.items() if v.is_admin and v.disabled}
def install(self) -> None:
"""Setup default user during first-time installation."""
self.add_user(username="admin", password="admin", is_admin=True, bypass_can_perform_action=True)
def _is_last_admin(self, username: str) -> bool:
return username in self.admins and len(self.admins) == 1
def add_user(
self, username: str, password: str, is_admin: bool = False, bypass_can_perform_action: bool = False
) -> bool:
"""
Adds a new user to the system.
:param username: The username for the new user
:param password: The password for the new user
:param is_admin: Flag indicating if the new user is an admin
:return: True if user was successfully added, False otherwise
"""
if not bypass_can_perform_action and not self._can_perform_action():
return False
if username in self.users:
self.sys_log.info(f"{self.name}: Failed to create new user {username} as this user name already exists")
return False
user = User(username=username, password=password, is_admin=is_admin)
self.users[username] = user
self.sys_log.info(f"{self.name}: Added new {'admin' if is_admin else 'user'}: {username}")
return True
def authenticate_user(self, username: str, password: str) -> Optional[User]:
"""
Authenticates a user's login attempt.
:param username: The username of the user trying to log in
:param password: The password provided by the user
:return: The User object if authentication is successful, None otherwise
"""
if not self._can_perform_action():
return None
user = self.users.get(username)
if user and not user.disabled and user.password == password:
self.sys_log.info(f"{self.name}: User authenticated: {username}")
return user
self.sys_log.info(f"{self.name}: Authentication failed for: {username}")
return None
def change_user_password(self, username: str, current_password: str, new_password: str) -> bool:
"""
Changes a user's password.
:param username: The username of the user changing their password
:param current_password: The current password of the user
:param new_password: The new password for the user
:return: True if the password was changed successfully, False otherwise
"""
if not self._can_perform_action():
return False
user = self.users.get(username)
if user and user.password == current_password:
user.password = new_password
self.sys_log.info(f"{self.name}: Password changed for {username}")
return True
self.sys_log.info(f"{self.name}: Password change failed for {username}")
return False
def disable_user(self, username: str) -> bool:
"""
Disables a user account, preventing them from logging in.
:param username: The username of the user to disable
:return: True if the user was disabled successfully, False otherwise
"""
if not self._can_perform_action():
return False
if username in self.users and not self.users[username].disabled:
if self._is_last_admin(username):
self.sys_log.info(f"{self.name}: Cannot disable User {username} as they are the only enabled admin")
return False
self.users[username].disabled = True
self.sys_log.info(f"{self.name}: User disabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to disable user: {username}")
return False
def enable_user(self, username: str) -> bool:
"""
Enables a previously disabled user account.
:param username: The username of the user to enable
:return: True if the user was enabled successfully, False otherwise
"""
if username in self.users and self.users[username].disabled:
self.users[username].disabled = False
self.sys_log.info(f"{self.name}: User enabled: {username}")
return True
self.sys_log.info(f"{self.name}: Failed to enable user: {username}")
return False
class UserSession(SimComponent):
"""
Represents a user session on the Node.
This class manages the state of a user session, including the user, session start, last active step,
and end step. It also indicates whether the session is local.
:ivar user: The user associated with this session.
:ivar start_step: The timestep when the session was started.
:ivar last_active_step: The last timestep when the session was active.
:ivar end_step: The timestep when the session ended, if applicable.
:ivar local: Indicates if the session is local. Defaults to True.
"""
user: User
"""The user associated with this session."""
start_step: int
"""The timestep when the session was started."""
last_active_step: int
"""The last timestep when the session was active."""
end_step: Optional[int] = None
"""The timestep when the session ended, if applicable."""
local: bool = True
"""Indicates if the session is a local session or a remote session. Defaults to True as a local session."""
@classmethod
def create(cls, user: User, timestep: int) -> UserSession:
"""
Creates a new instance of UserSession.
This class method initialises a user session with the given user and timestep.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:return: An instance of UserSession.
"""
user.num_of_logins += 1
return UserSession(user=user, start_step=timestep, last_active_step=timestep)
def describe_state(self) -> Dict:
"""
Describes the current state of the user session.
:return: A dictionary representing the state of the user session.
"""
return self.model_dump()
class RemoteUserSession(UserSession):
"""
Represents a remote user session on the Node.
This class extends the UserSession class to include additional attributes and methods specific to remote sessions.
:ivar remote_ip_address: The IP address of the remote user.
:ivar local: Indicates that this is not a local session. Always set to False.
"""
remote_ip_address: IPV4Address
"""The IP address of the remote user."""
local: bool = False
"""Indicates that this is not a local session. Always set to False."""
@classmethod
def create(cls, user: User, timestep: int, remote_ip_address: IPV4Address) -> RemoteUserSession: # noqa
"""
Creates a new instance of RemoteUserSession.
This class method initialises a remote user session with the given user, timestep, and remote IP address.
:param user: The user associated with this session.
:param timestep: The timestep when the session is created.
:param remote_ip_address: The IP address of the remote user.
:return: An instance of RemoteUserSession.
"""
return RemoteUserSession(
user=user, start_step=timestep, last_active_step=timestep, remote_ip_address=remote_ip_address
)
def describe_state(self) -> Dict:
"""
Describes the current state of the remote user session.
This method extends the base describe_state method to include the remote IP address.
:return: A dictionary representing the state of the remote user session.
"""
state = super().describe_state()
state["remote_ip_address"] = str(self.remote_ip_address)
return state
class UserSessionManager(Service):
"""
Manages user sessions on a Node, including local and remote sessions.
This class handles authentication, session management, and session timeouts for users interacting with the Node.
"""
local_session: Optional[UserSession] = None
"""The current local user session, if any."""
remote_sessions: Dict[str, RemoteUserSession] = {}
"""A dictionary of active remote user sessions."""
historic_sessions: List[UserSession] = Field(default_factory=list)
"""A list of historic user sessions."""
local_session_timeout_steps: int = 30
"""The number of steps before a local session times out due to inactivity."""
remote_session_timeout_steps: int = 5
"""The number of steps before a remote session times out due to inactivity."""
max_remote_sessions: int = 3
"""The maximum number of concurrent remote sessions allowed."""
current_timestep: int = 0
"""The current timestep in the simulation."""
def __init__(self, **kwargs):
"""
Initializes a UserSessionManager instance.
:param username: The username for the default admin user
:param password: The password for the default admin user
"""
kwargs["name"] = "UserSessionManager"
kwargs["port"] = Port.NONE
kwargs["protocol"] = IPProtocol.NONE
super().__init__(**kwargs)
self.start()
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
# todo add doc about requeest schemas
rm.add_request(
"remote_login",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.remote_login(username=request[0], password=request[1], remote_ip_address=request[2])
)
),
)
rm.add_request(
"remote_logout",
RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.remote_logout(remote_session_id=request[0])
)
),
)
return rm
def show(self, markdown: bool = False, include_session_id: bool = False, include_historic: bool = False):
"""
Displays a table of the user sessions on the Node.
:param markdown: Whether to display the table in markdown format.
:param include_session_id: Whether to include session IDs in the table.
:param include_historic: Whether to include historic sessions in the table.
"""
headers = ["Session ID", "Username", "Type", "Remote IP", "Start Step", "Step Last Active", "End Step"]
if not include_session_id:
headers = headers[1:]
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.parent.hostname} User Sessions"
def _add_session_to_table(user_session: UserSession):
"""
Adds a user session to the table for display.
This helper function determines whether the session is local or remote and formats the session data
accordingly. It then adds the session data to the table.
:param user_session: The user session to add to the table.
"""
session_type = "local"
remote_ip = ""
if isinstance(user_session, RemoteUserSession):
session_type = "remote"
remote_ip = str(user_session.remote_ip_address)
data = [
user_session.uuid,
user_session.user.username,
session_type,
remote_ip,
user_session.start_step,
user_session.last_active_step,
user_session.end_step if user_session.end_step else "",
]
if not include_session_id:
data = data[1:]
table.add_row(data)
if self.local_session is not None:
_add_session_to_table(self.local_session)
for user_session in self.remote_sessions.values():
_add_session_to_table(user_session)
if include_historic:
for user_session in self.historic_sessions:
_add_session_to_table(user_session)
print(table.get_string(sortby="Step Last Active", reversesort=True))
def describe_state(self) -> Dict:
"""
Describes the current state of the UserSessionManager.
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["active_remote_logins"] = len(self.remote_sessions)
return state
@property
def _user_manager(self) -> UserManager:
"""
Returns the UserManager instance.
:return: The UserManager instance.
"""
return self.software_manager.software["UserManager"] # noqa
def pre_timestep(self, timestep: int) -> None:
"""Apply any pre-timestep logic that helps make sure we have the correct observations."""
self.current_timestep = timestep
if self.local_session:
if self.local_session.last_active_step + self.local_session_timeout_steps <= timestep:
self._timeout_session(self.local_session)
def _timeout_session(self, session: UserSession) -> None:
"""
Handles session timeout logic.
:param session: The session to be timed out.
"""
session.end_step = self.current_timestep
session_identity = session.user.username
if session.local:
self.local_session = None
session_type = "Local"
else:
self.remote_sessions.pop(session.uuid)
session_type = "Remote"
session_identity = f"{session_identity} {session.remote_ip_address}"
self.sys_log.info(f"{self.name}: {session_type} {session_identity} session timeout due to inactivity")
@property
def remote_session_limit_reached(self) -> bool:
"""
Checks if the maximum number of remote sessions has been reached.
:return: True if the limit is reached, otherwise False.
"""
return len(self.remote_sessions) >= self.max_remote_sessions
def validate_remote_session_uuid(self, remote_session_id: str) -> bool:
"""
Validates if a given remote session ID exists.
:param remote_session_id: The remote session ID to validate.
:return: True if the session ID exists, otherwise False.
"""
return remote_session_id in self.remote_sessions
def _login(
self, username: str, password: str, local: bool = True, remote_ip_address: Optional[IPv4Address] = None
) -> Optional[str]:
"""
Logs a user in either locally or remotely.
:param username: The username of the account.
:param password: The password of the account.
:param local: Whether the login is local or remote.
:param remote_ip_address: The remote IP address for remote login.
:return: The session ID if login is successful, otherwise None.
"""
if not self._can_perform_action():
return None
user = self._user_manager.authenticate_user(username=username, password=password)
if not user:
self.sys_log.info(f"{self.name}: Incorrect username or password")
return None
session_id = None
if local:
create_new_session = True
if self.local_session:
if self.local_session.user != user:
# logout the current user
self.local_logout()
else:
# not required as existing logged-in user attempting to re-login
create_new_session = False
if create_new_session:
self.local_session = UserSession.create(user=user, timestep=self.current_timestep)
session_id = self.local_session.uuid
else:
if not self.remote_session_limit_reached:
remote_session = RemoteUserSession.create(
user=user, timestep=self.current_timestep, remote_ip_address=remote_ip_address
)
session_id = remote_session.uuid
self.remote_sessions[session_id] = remote_session
self.sys_log.info(f"{self.name}: User {user.username} logged in")
return session_id
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Logs a user in locally.
:param username: The username of the account.
:param password: The password of the account.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=True)
@validate_call()
def remote_login(self, username: str, password: str, remote_ip_address: IPV4Address) -> Optional[str]:
"""
Logs a user in remotely.
:param username: The username of the account.
:param password: The password of the account.
:param remote_ip_address: The remote IP address for the remote login.
:return: The session ID if login is successful, otherwise None.
"""
return self._login(username=username, password=password, local=False, remote_ip_address=remote_ip_address)
def _logout(self, local: bool = True, remote_session_id: Optional[str] = None) -> bool:
"""
Logs a user out either locally or remotely.
:param local: Whether the logout is local or remote.
:param remote_session_id: The remote session ID for remote logout.
:return: True if logout successful, otherwise False.
"""
if not self._can_perform_action():
return False
session = None
if local and self.local_session:
session = self.local_session
session.end_step = self.current_timestep
self.local_session = None
if not local and remote_session_id:
session = self.remote_sessions.pop(remote_session_id)
if session:
self.historic_sessions.append(session)
self.sys_log.info(f"{self.name}: User {session.user.username} logged out")
return True
return False
def local_logout(self) -> bool:
"""
Logs out the current local user.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=True)
def remote_logout(self, remote_session_id: str) -> bool:
"""
Logs out a remote user by session ID.
:param remote_session_id: The remote session ID.
:return: True if logout successful, otherwise False.
"""
return self._logout(local=False, remote_session_id=remote_session_id)
@property
def local_user_logged_in(self) -> bool:
"""
Checks if a local user is currently logged in.
:return: True if a local user is logged in, otherwise False.
"""
return self.local_session is not None
class Node(SimComponent):
"""
A basic Node class that represents a node on the network.
@@ -861,11 +1501,14 @@ class Node(SimComponent):
red_scan_countdown: int = 0
"Time steps until reveal to red scan is complete."
SYSTEM_SOFTWARE: ClassVar[Dict[str, Type[Software]]] = {}
"Base system software that must be preinstalled."
def __init__(self, **kwargs):
"""
Initialize the Node with various components and managers.
This method initializes the ARP cache, ICMP handler, session manager, and software manager if they are not
This method initialises the ARP cache, ICMP handler, session manager, and software manager if they are not
provided.
"""
if not kwargs.get("sys_log"):
@@ -885,9 +1528,45 @@ class Node(SimComponent):
dns_server=kwargs.get("dns_server"),
)
super().__init__(**kwargs)
self._install_system_software()
self.session_manager.node = self
self.session_manager.software_manager = self.software_manager
self._install_system_software()
@property
def user_manager(self) -> Optional[UserManager]:
"""The Nodes User Manager."""
return self.software_manager.software.get("UserManager") # noqa
@property
def user_session_manager(self) -> Optional[UserSessionManager]:
"""The Nodes User Session Manager."""
return self.software_manager.software.get("UserSessionManager") # noqa
@property
def terminal(self) -> Optional[Terminal]:
"""The Nodes Terminal."""
return self.software_manager.software.get("Terminal")
def local_login(self, username: str, password: str) -> Optional[str]:
"""
Attempt to log in to the node uas a local user.
This method attempts to authenticate a local user with the given username and password. If successful, it
returns a session token. If authentication fails, it returns None.
:param username: The username of the account attempting to log in.
:param password: The password of the account attempting to log in.
:return: A session token if the login is successful, otherwise None.
"""
return self.user_session_manager.local_login(username, password)
def local_logout(self) -> None:
"""
Log out the current local user from the node.
This method ends the current local user's session and invalidates the session token.
"""
return self.user_session_manager.local_logout()
def ip_is_network_interface(self, ip_address: IPv4Address, enabled_only: bool = False) -> bool:
"""
@@ -942,7 +1621,7 @@ class Node(SimComponent):
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
return f"Cannot perform request on node '{self.node.hostname}' because it is not powered on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
@@ -984,7 +1663,7 @@ class Node(SimComponent):
application_name = request[0]
if self.software_manager.software.get(application_name):
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
return RequestResponse.from_bool(False)
return RequestResponse(status="success", data={"reason": "already installed"})
application_class = Application._application_registry[application_name]
self.software_manager.install(application_class)
application_instance = self.software_manager.software.get(application_name)
@@ -1091,10 +1770,6 @@ class Node(SimComponent):
return rm
def _install_system_software(self):
"""Install System Software - software that is usually provided with the OS."""
pass
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -1173,7 +1848,7 @@ class Node(SimComponent):
ip_address,
network_interface.speed,
"Enabled" if network_interface.enabled else "Disabled",
network_interface.nmne if primaite.simulator.network.nmne.CAPTURE_NMNE else "Disabled",
network_interface.nmne if network_interface.nmne_config.capture_nmne else "Disabled",
]
)
print(table)
@@ -1455,74 +2130,6 @@ class Node(SimComponent):
else:
return
def install_service(self, service: Service) -> None:
"""
Install a service on this node.
:param service: Service instance that has not been installed on any node yet.
:type service: Service
"""
if service in self:
_LOGGER.warning(f"Can't add service {service.name} to node {self.hostname}. It's already installed.")
return
self.services[service.uuid] = service
service.parent = self
service.install() # Perform any additional setup, such as creating files for this service on the node.
self.sys_log.info(f"Installed service {service.name}")
_LOGGER.debug(f"Added service {service.name} to node {self.hostname}")
self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager))
def uninstall_service(self, service: Service) -> None:
"""
Uninstall and completely remove service from this node.
:param service: Service object that is currently associated with this node.
:type service: Service
"""
if service not in self:
_LOGGER.warning(f"Can't remove service {service.name} from node {self.hostname}. It's not installed.")
return
service.uninstall() # Perform additional teardown, such as removing files or restarting the machine.
self.services.pop(service.uuid)
service.parent = None
self.sys_log.info(f"Uninstalled service {service.name}")
self._service_request_manager.remove_request(service.name)
def install_application(self, application: Application) -> None:
"""
Install an application on this node.
:param application: Application instance that has not been installed on any node yet.
:type application: Application
"""
if application in self:
_LOGGER.warning(
f"Can't add application {application.name} to node {self.hostname}. It's already installed."
)
return
self.applications[application.uuid] = application
application.parent = self
self.sys_log.info(f"Installed application {application.name}")
_LOGGER.debug(f"Added application {application.name} to node {self.hostname}")
self._application_request_manager.add_request(application.name, RequestType(func=application._request_manager))
def uninstall_application(self, application: Application) -> None:
"""
Uninstall and completely remove application from this node.
:param application: Application object that is currently associated with this node.
:type application: Application
"""
if application not in self:
_LOGGER.warning(
f"Can't remove application {application.name} from node {self.hostname}. It's not installed."
)
return
self.applications.pop(application.uuid)
application.parent = None
self.sys_log.info(f"Uninstalled application {application.name}")
self._application_request_manager.remove_request(application.name)
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
@@ -1551,6 +2158,11 @@ class Node(SimComponent):
# for process_id in self.processes:
# self.processes[process_id]
def _install_system_software(self) -> None:
"""Preinstall required software."""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services

View File

@@ -5,7 +5,13 @@ from ipaddress import IPv4Address
from typing import Any, ClassVar, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node
from primaite.simulator.network.hardware.base import (
IPWiredNetworkInterface,
Link,
Node,
UserManager,
UserSessionManager,
)
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.application import ApplicationOperatingState
@@ -15,6 +21,7 @@ from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
_LOGGER = getLogger(__name__)
@@ -292,6 +299,7 @@ class HostNode(Node):
* DNS (Domain Name System) Client: Resolves domain names to IP addresses.
* FTP (File Transfer Protocol) Client: Enables file transfers between the host and FTP servers.
* NTP (Network Time Protocol) Client: Synchronizes the system clock with NTP servers.
* Terminal Client: Handles SSH requests between HostNode and external components.
Applications:
------------
@@ -306,6 +314,9 @@ class HostNode(Node):
"NTPClient": NTPClient,
"WebBrowser": WebBrowser,
"NMAP": NMAP,
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
"""List of system software that is automatically installed on nodes."""
@@ -338,18 +349,6 @@ class HostNode(Node):
"""
return self.software_manager.software.get("ARP")
def _install_system_software(self):
"""
Installs the system software and network services typically found on an operating system.
This method equips the host with essential network services and applications, preparing it for various
network-related tasks and operations.
"""
for _, software_class in self.SYSTEM_SOFTWARE.items():
self.software_manager.install(software_class)
super()._install_system_software()
def default_gateway_hello(self):
"""
Sends a hello message to the default gateway to establish connectivity and resolve the gateway's MAC address.

View File

@@ -4,14 +4,14 @@ from __future__ import annotations
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from prettytable import MARKDOWN, PrettyTable
from pydantic import validate_call
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface
from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, UserManager, UserSessionManager
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.network_node import NetworkNode
from primaite.simulator.network.protocols.arp import ARPPacket
@@ -24,6 +24,7 @@ from primaite.simulator.system.core.session_manager import SessionManager
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.terminal.terminal import Terminal
from primaite.utils.validators import IPV4Address
@@ -1200,6 +1201,12 @@ class Router(NetworkNode):
RouteTable, RouterARP, and RouterICMP services.
"""
SYSTEM_SOFTWARE: ClassVar[Dict] = {
"UserSessionManager": UserSessionManager,
"UserManager": UserManager,
"Terminal": Terminal,
}
num_ports: int
network_interfaces: Dict[str, RouterInterface] = {}
"The Router Interfaces on the node."
@@ -1235,6 +1242,7 @@ class Router(NetworkNode):
resolution within the network. These services are crucial for the router's operation, enabling it to manage
network traffic efficiently.
"""
super()._install_system_software()
self.software_manager.install(RouterICMP)
icmp: RouterICMP = self.software_manager.icmp # noqa
icmp.router = self

View File

@@ -108,6 +108,9 @@ class Switch(NetworkNode):
for i in range(1, self.num_ports + 1):
self.connect_nic(SwitchPort())
def _install_system_software(self):
pass
def show(self, markdown: bool = False):
"""
Prints a table of the SwitchPorts on the Switch.

View File

@@ -1,48 +1,25 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict, Final, List
from typing import List
CAPTURE_NMNE: bool = True
"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True."""
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination IP address."""
CAPTURE_BY_PROTOCOL: Final[bool] = False
"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP)."""
CAPTURE_BY_PORT: Final[bool] = False
"""Flag to determine if captures should be organized by source or destination port."""
CAPTURE_BY_KEYWORD: Final[bool] = False
"""Flag to determine if captures should be filtered and categorised based on specific keywords."""
from pydantic import BaseModel, ConfigDict
def set_nmne_config(nmne_config: Dict):
"""
Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary.
class NMNEConfig(BaseModel):
"""Store all the information to perform NMNE operations."""
This function updates global settings related to NMNE capture, including whether to capture NMNEs and what
keywords to use for identifying NMNEs.
model_config = ConfigDict(extra="forbid")
The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary,
and maintains type integrity by checking the types of the provided values.
:param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include:
"capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings)
to specify keywords for NMNE identification.
"""
global NMNE_CAPTURE_KEYWORDS
global CAPTURE_NMNE
# Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect
CAPTURE_NMNE = nmne_config.get("capture_nmne", False)
if not isinstance(CAPTURE_NMNE, bool):
CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean
# Update the NMNE capture keywords, appending new keywords if provided
NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", [])
if not isinstance(NMNE_CAPTURE_KEYWORDS, list):
NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list
capture_nmne: bool = False
"""Indicates whether Malicious Network Events (MNEs) should be captured."""
nmne_capture_keywords: List[str] = []
"""List of keywords to identify malicious network events."""
capture_by_direction: bool = True
"""Captures should be organized by traffic direction (inbound/outbound)."""
capture_by_ip_address: bool = False
"""Captures should be organized by source or destination IP address."""
capture_by_protocol: bool = False
"""Captures should be organized by network protocol (e.g., TCP, UDP)."""
capture_by_port: bool = False
"""Captures should be organized by source or destination port."""
capture_by_keyword: bool = False
"""Captures should be filtered and categorised based on specific keywords."""

View File

@@ -4,7 +4,7 @@ from enum import Enum
from typing import Union
from pydantic import BaseModel, field_validator, validate_call
from pydantic_core.core_schema import FieldValidationInfo
from pydantic_core.core_schema import ValidationInfo
from primaite import getLogger
@@ -96,7 +96,7 @@ class ICMPPacket(BaseModel):
@field_validator("icmp_code") # noqa
@classmethod
def _icmp_type_must_have_icmp_code(cls, v: int, info: FieldValidationInfo) -> int:
def _icmp_type_must_have_icmp_code(cls, v: int, info: ValidationInfo) -> int:
"""Validates the icmp_type and icmp_code."""
icmp_type = info.data["icmp_type"]
if get_icmp_type_code_description(icmp_type, v):

View File

@@ -0,0 +1,89 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from enum import IntEnum
from typing import Optional
from primaite.interface.request import RequestResponse
from primaite.simulator.network.protocols.packet import DataPacket
class SSHTransportMessage(IntEnum):
"""
Enum list of Transport layer messages that can be handled by the simulation.
Each msg value is equivalent to the real-world.
"""
SSH_MSG_USERAUTH_REQUEST = 50
"""Requests User Authentication."""
SSH_MSG_USERAUTH_FAILURE = 51
"""Indicates User Authentication failed."""
SSH_MSG_USERAUTH_SUCCESS = 52
"""Indicates User Authentication was successful."""
SSH_MSG_SERVICE_REQUEST = 24
"""Requests a service - such as executing a command."""
# These two msgs are invented for primAITE however are modelled on reality
SSH_MSG_SERVICE_FAILED = 25
"""Indicates that the requested service failed."""
SSH_MSG_SERVICE_SUCCESS = 26
"""Indicates that the requested service was successful."""
class SSHConnectionMessage(IntEnum):
"""Int Enum list of all SSH's connection protocol messages that can be handled by the simulation."""
SSH_MSG_CHANNEL_OPEN = 80
"""Requests an open channel - Used in combination with SSH_MSG_USERAUTH_REQUEST."""
SSH_MSG_CHANNEL_OPEN_CONFIRMATION = 81
"""Confirms an open channel."""
SSH_MSG_CHANNEL_OPEN_FAILED = 82
"""Indicates that channel opening failure."""
SSH_MSG_CHANNEL_DATA = 84
"""Indicates that data is being sent through the channel."""
SSH_MSG_CHANNEL_CLOSE = 87
"""Closes the channel."""
class SSHUserCredentials(DataPacket):
"""Hold Username and Password in SSH Packets."""
username: str
"""Username for login"""
password: str
"""Password for login"""
class SSHPacket(DataPacket):
"""Represents an SSHPacket."""
transport_message: SSHTransportMessage
"""Message Transport Type"""
connection_message: SSHConnectionMessage
"""Message Connection Status"""
user_account: Optional[SSHUserCredentials] = None
"""User Account Credentials if passed"""
connection_request_uuid: Optional[str] = None
"""Connection Request UUID used when establishing a remote connection"""
connection_uuid: Optional[str] = None
"""Connection UUID used when validating a remote connection"""
ssh_output: Optional[RequestResponse] = None
"""RequestResponse from Request Manager"""
ssh_command: Optional[list] = None
"""Request String"""

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from uuid import uuid4
from prettytable import MARKDOWN, PrettyTable
@@ -54,6 +54,12 @@ class DatabaseClientConnection(BaseModel):
if self.client and self.is_active:
self.client._disconnect(self.connection_id) # noqa
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_id}', is_active={self.is_active})"
def __repr__(self) -> str:
return str(self)
class DatabaseClient(Application, identifier="DatabaseClient"):
"""
@@ -76,7 +82,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""Connection ID to the Database Server."""
client_connections: Dict[str, DatabaseClientConnection] = {}
"""Keep track of active connections to Database Server."""
_client_connection_requests: Dict[str, Optional[str]] = {}
_client_connection_requests: Dict[str, Optional[Union[str, DatabaseClientConnection]]] = {}
"""Dictionary of connection requests to Database Server."""
connected: bool = False
"""Boolean Value for whether connected to DB Server."""
@@ -187,7 +193,7 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
return False
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
def _check_client_connection(self, connection_id: str) -> bool:
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
@@ -211,23 +217,30 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
:type: is_reattempt: Optional[bool]
"""
if is_reattempt:
valid_connection = self._check_client_connection(connection_id=connection_request_id)
if valid_connection:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
database_client_connection = self._client_connection_requests.pop(connection_request_id)
self.sys_log.info(
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
f"Connection Request ID was {connection_request_id}."
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
if isinstance(database_client_connection, DatabaseClientConnection):
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} authorised. "
f"Using connection id {database_client_connection}"
)
self.connected = True
self._last_connection_successful = True
return database_client_connection
else:
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined"
)
self._last_connection_successful = False
return None
else:
self.sys_log.warning(
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
f"Connection Request ID was {connection_request_id}."
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) to {server_ip_address} declined "
f"due to unknown client-side connection request id"
)
self._last_connection_successful = False
return None
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
@@ -300,9 +313,14 @@ class DatabaseClient(Application, identifier="DatabaseClient"):
"""
if not self._can_perform_action():
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
self.sys_log.info(
f"{self.name}: Sending new connection request ({connection_request_id}) to {self.server_ip_address}"
)
return self._connect(
server_ip_address=self.server_ip_address,
password=self.server_password,

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator.core import RequestType
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -20,9 +21,7 @@ if TYPE_CHECKING:
from primaite.simulator.system.services.arp.arp import ARP
from primaite.simulator.system.services.icmp.icmp import ICMP
from typing import Type, TypeVar
IOSoftwareClass = TypeVar("IOSoftwareClass", bound=IOSoftware)
from typing import Type
class SoftwareManager:
@@ -51,7 +50,7 @@ class SoftwareManager:
self.node = parent_node
self.session_manager = session_manager
self.software: Dict[str, Union[Service, Application]] = {}
self._software_class_to_name_map: Dict[Type[IOSoftwareClass], str] = {}
self._software_class_to_name_map: Dict[Type[IOSoftware], str] = {}
self.port_protocol_mapping: Dict[Tuple[Port, IPProtocol], Union[Service, Application]] = {}
self.sys_log: SysLog = sys_log
self.file_system: FileSystem = file_system
@@ -104,33 +103,38 @@ class SoftwareManager:
return True
return False
def install(self, software_class: Type[IOSoftwareClass]):
def install(self, software_class: Type[IOSoftware], **install_kwargs):
"""
Install an Application or Service.
:param software_class: The software class.
"""
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
# separation of concerns.
if software_class in self._software_class_to_name_map:
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
return
software = software_class(
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
software_manager=self,
sys_log=self.sys_log,
file_system=self.file_system,
dns_server=self.dns_server,
**install_kwargs,
)
software.parent = self.node
if isinstance(software, Application):
software.install()
self.node.applications[software.uuid] = software
self.node._application_request_manager.add_request(
software.name, RequestType(func=software._request_manager)
)
elif isinstance(software, Service):
self.node.services[software.uuid] = software
self.node._service_request_manager.add_request(software.name, RequestType(func=software._request_manager))
software.install()
software.software_manager = self
self.software[software.name] = software
self.port_protocol_mapping[(software.port, software.protocol)] = software
if isinstance(software, Application):
software.operating_state = ApplicationOperatingState.CLOSED
# add the software to the node's registry after it has been fully initialized
if isinstance(software, Service):
self.node.install_service(software)
elif isinstance(software, Application):
self.node.install_application(software)
self.node.sys_log.info(f"Installed {software.name}")
def uninstall(self, software_name: str):
"""
@@ -138,25 +142,31 @@ class SoftwareManager:
:param software_name: The software name.
"""
if software_name in self.software:
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
if software_name not in self.software:
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")
self.software[software_name].uninstall()
software = self.software.pop(software_name) # noqa
if isinstance(software, Application):
self.node.applications.pop(software.uuid)
self.node._application_request_manager.remove_request(software.name)
elif isinstance(software, Service):
self.node.services.pop(software.uuid)
software.uninstall()
self.node._service_request_manager.remove_request(software.name)
software.parent = None
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Uninstalled {software_name}")
return
def send_internal_payload(self, target_software: str, payload: Any):
"""

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -191,12 +191,16 @@ class DatabaseService(Service):
:return: Response to connection request containing success info.
:rtype: Dict[str, Union[int, Dict[str, bool]]]
"""
self.sys_log.info(f"{self.name}: Processing new connection request ({connection_request_id}) from {src_ip}")
status_code = 500 # Default internal server error
connection_id = None
if self.operating_state == ServiceOperatingState.RUNNING:
status_code = 503 # service unavailable
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, service is at "
f"capacity."
)
if self.health_state_actual in [
SoftwareHealthState.GOOD,
SoftwareHealthState.FIXING,
@@ -208,12 +212,16 @@ class DatabaseService(Service):
# try to create connection
if not self.add_connection(connection_id=connection_id, session_id=session_id):
status_code = 500
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
else:
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} declined, "
f"returning status code 500"
)
else:
status_code = 401 # Unauthorised
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
self.sys_log.info(
f"{self.name}: Connection request ({connection_request_id}) from {src_ip} unauthorised "
f"(incorrect password), returning status code 401"
)
else:
status_code = 404 # service not found
return {

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,523 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.service import Service, ServiceOperatingState
class TerminalClientConnection(BaseModel):
"""
TerminalClientConnection Class.
This class is used to record current User Connections to the Terminal class.
"""
parent_terminal: Terminal
"""The parent Node that this connection was created on."""
session_id: str = None
"""Session ID that connection is linked to"""
connection_uuid: str = None
"""Connection UUID"""
connection_request_id: str = None
"""Connection request ID"""
time: datetime = None
"""Timestamp connection was created."""
ip_address: IPv4Address
"""Source IP of Connection"""
is_active: bool = True
"""Flag to state whether the connection is active or not"""
def __str__(self) -> str:
return f"{self.__class__.__name__}(connection_id='{self.connection_uuid}')"
def __repr__(self) -> str:
return self.__str__()
def __getitem__(self, key: Any) -> Any:
return getattr(self, key)
@property
def client(self) -> Optional[Terminal]:
"""The Terminal that holds this connection."""
return self.parent_terminal
def disconnect(self) -> bool:
"""Disconnect the session."""
return self.parent_terminal._disconnect(connection_uuid=self.connection_uuid)
@abstractmethod
def execute(self, command: Any) -> bool:
"""Execute a given command."""
pass
class LocalTerminalConnection(TerminalClientConnection):
"""
LocalTerminalConnectionClass.
This class represents a local terminal when connected.
"""
ip_address: str = "Local Connection"
def execute(self, command: Any) -> Optional[RequestResponse]:
"""Execute a given command on local Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return None
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return None
return self.parent_terminal.execute(command, connection_id=self.connection_uuid)
class RemoteTerminalConnection(TerminalClientConnection):
"""
RemoteTerminalConnection Class.
This class acts as broker between the terminal and remote.
"""
def execute(self, command: Any) -> bool:
"""Execute a given command on the remote Terminal."""
if self.parent_terminal.operating_state != ServiceOperatingState.RUNNING:
self.parent_terminal.sys_log.warning("Cannot process command as system not running")
return False
if not self.is_active:
self.parent_terminal.sys_log.warning("Connection inactive, cannot execute")
return False
# Send command to remote terminal to process.
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_SERVICE_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload: SSHPacket = SSHPacket(
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=self.connection_request_id,
connection_uuid=self.connection_uuid,
ssh_command=command,
)
return self.parent_terminal.send(payload=payload, session_id=self.session_id)
class Terminal(Service):
"""Class used to simulate a generic terminal service. Can be interacted with by other terminals via SSH."""
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
def __init__(self, **kwargs):
kwargs["name"] = "Terminal"
kwargs["port"] = Port.SSH
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
state = super().describe_state()
return state
def show(self, markdown: bool = False):
"""
Display the remote connections to this terminal instance in tabular format.
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
"""
self.show_connections(markdown=markdown)
def _init_request_manager(self) -> RequestManager:
"""Initialise Request manager."""
rm = super()._init_request_manager()
rm.add_request(
"send",
request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.send())),
)
def _login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._process_local_login(username=request[0], password=request[1])
if login:
return RequestResponse(status="success", data={})
else:
return RequestResponse(status="failure", data={})
def _remote_login(request: RequestFormat, context: Dict) -> RequestResponse:
login = self._send_remote_login(username=request[0], password=request[1], ip_address=request[2])
if login:
return RequestResponse(status="success", data={})
else:
return RequestResponse(status="failure", data={})
def _execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
"""Execute an instruction."""
command: str = request[0]
connection_id: str = request[1]
self.execute(command, connection_id=connection_id)
return RequestResponse(status="success", data={})
def _logoff(request: RequestFormat, context: Dict) -> RequestResponse:
"""Logoff from connection."""
connection_uuid = request[0]
# TODO: Uncomment this when UserSessionManager merged.
# self.parent.UserSessionManager.logoff(connection_uuid)
self._disconnect(connection_uuid)
return RequestResponse(status="success", data={})
rm.add_request(
"Login",
request_type=RequestType(func=_login),
)
rm.add_request(
"Remote Login",
request_type=RequestType(func=_remote_login),
)
rm.add_request(
"Execute",
request_type=RequestType(func=_execute_request),
)
rm.add_request("Logoff", request_type=RequestType(func=_logoff))
return rm
def execute(self, command: List[Any], connection_id: str) -> Optional[RequestResponse]:
"""Execute a passed ssh command via the request manager."""
valid_connection = self._check_client_connection(connection_id=connection_id)
if valid_connection:
return self.parent.apply_request(command)
else:
self.sys_log.error("Invalid connection ID provided")
return None
def _create_local_connection(self, connection_uuid: str, session_id: str) -> TerminalClientConnection:
"""Create a new connection object and amend to list of active connections.
:param connection_uuid: Connection ID of the new local connection
:param session_id: Session ID of the new local connection
:return: TerminalClientConnection object
"""
new_connection = LocalTerminalConnection(
parent_terminal=self,
connection_uuid=connection_uuid,
session_id=session_id,
time=datetime.now(),
)
self._connections[connection_uuid] = new_connection
self._client_connection_requests[connection_uuid] = new_connection
return new_connection
def login(
self, username: str, password: str, ip_address: Optional[IPv4Address] = None
) -> Optional[TerminalClientConnection]:
"""Login to the terminal. Will attempt a remote login if ip_address is given, else local.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt. If None, login is assumed local.
:type: ip_address: Optional[IPv4Address]
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning("Cannot login as service is not running.")
return None
connection_request_id = str(uuid4())
self._client_connection_requests[connection_request_id] = None
if ip_address:
# Assuming that if IP is passed we are connecting to remote
return self._send_remote_login(
username=username, password=password, ip_address=ip_address, connection_request_id=connection_request_id
)
else:
return self._process_local_login(username=username, password=password)
def _process_local_login(self, username: str, password: str) -> Optional[TerminalClientConnection]:
"""Local session login to terminal.
:param username: Username for login.
:param password: Password for login.
:return: boolean, True if successful, else False
"""
# TODO: Un-comment this when UserSessionManager is merged.
# connection_uuid = self.parent.UserSessionManager.login(username=username, password=password)
connection_uuid = str(uuid4())
if connection_uuid:
self.sys_log.info(f"Login request authorised, connection uuid: {connection_uuid}")
# Add new local session to list of connections and return
return self._create_local_connection(connection_uuid=connection_uuid, session_id="Local_Connection")
else:
self.sys_log.warning("Login failed, incorrect Username or Password")
return None
def _validate_client_connection_request(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._client_connection_requests else False
def _check_client_connection(self, connection_id: str) -> bool:
"""Check that client_connection_id is valid."""
return True if connection_id in self._connections else False
def _send_remote_login(
self,
username: str,
password: str,
ip_address: IPv4Address,
connection_request_id: str,
is_reattempt: bool = False,
) -> Optional[RemoteTerminalConnection]:
"""Send a remote login attempt and connect to Node.
:param: username: Username used to connect to the remote node.
:type: username: str
:param: password: Password used to connect to the remote node
:type: password: str
:param: ip_address: Target Node IP address for login attempt.
:type: ip_address: IPv4Address
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: is_reattempt: True if the request has been reattempted. Default False.
:type: is_reattempt: Optional[bool]
:return: RemoteTerminalConnection: Connection Object for sending further commands if successful, else False.
"""
self.sys_log.info(f"Sending Remote login attempt to {ip_address}. Connection_id is {connection_request_id}")
if is_reattempt:
valid_connection_request = self._validate_client_connection_request(connection_id=connection_request_id)
if valid_connection_request:
remote_terminal_connection = self._client_connection_requests.pop(connection_request_id)
if isinstance(remote_terminal_connection, RemoteTerminalConnection):
self.sys_log.info(f"{self.name}: Remote Connection to {ip_address} authorised.")
return remote_terminal_connection
else:
self.sys_log.warning(f"Connection request{connection_request_id} declined")
return None
else:
self.sys_log.warning(f"{self.name}: Remote connection to {ip_address} declined.")
return None
transport_message: SSHTransportMessage = SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST
connection_message: SSHConnectionMessage = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
user_details: SSHUserCredentials = SSHUserCredentials(username=username, password=password)
payload_contents = {
"type": "login_request",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
user_account=user_details,
connection_request_uuid=connection_request_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=ip_address, dest_port=self.port
)
return self._send_remote_login(
username=username,
password=password,
ip_address=ip_address,
is_reattempt=True,
connection_request_id=connection_request_id,
)
def _create_remote_connection(
self, connection_id: str, connection_request_id: str, session_id: str, source_ip: str
) -> None:
"""Create a new TerminalClientConnection Object.
:param: connection_request_id: Connection Request ID
:type: connection_request_id: str
:param: session_id: Session ID of connection.
:type: session_id: str
"""
client_connection = RemoteTerminalConnection(
parent_terminal=self,
session_id=session_id,
connection_uuid=connection_id,
ip_address=source_ip,
connection_request_id=connection_request_id,
time=datetime.now(),
)
self._connections[connection_id] = client_connection
self._client_connection_requests[connection_request_id] = client_connection
def receive(self, session_id: str, payload: Union[SSHPacket, Dict], **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
:param payload: A payload to receive.
:param session_id: The session id the payload relates to.
:return: True.
"""
source_ip = kwargs["from_network_interface"].ip_address
self.sys_log.info(f"Received payload: {payload}. Source: {source_ip}")
if isinstance(payload, SSHPacket):
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
# validate & add connection
# TODO: uncomment this as part of 2781
# connection_id = self.parent.UserSessionManager.login(username=username, password=password)
connection_id = str(uuid4())
if connection_id:
connection_request_id = payload.connection_request_uuid
username = payload.user_account.username
password = payload.user_account.password
print(f"Connection ID is: {connection_request_id}")
self.sys_log.info(f"Connection authorised, session_id: {session_id}")
self._create_remote_connection(
connection_id=connection_id,
connection_request_id=connection_request_id,
session_id=session_id,
source_ip=source_ip,
)
transport_message = SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS
connection_message = SSHConnectionMessage.SSH_MSG_CHANNEL_DATA
payload_contents = {
"type": "login_success",
"username": username,
"password": password,
"connection_request_id": connection_request_id,
"connection_id": connection_id,
}
payload: SSHPacket = SSHPacket(
payload=payload_contents,
transport_message=transport_message,
connection_message=connection_message,
connection_request_uuid=connection_request_id,
connection_uuid=connection_id,
)
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload=payload, dest_port=self.port, session_id=session_id
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_SUCCESS:
self.sys_log.info("Login Successful")
self._create_remote_connection(
connection_id=payload.connection_uuid,
connection_request_id=payload.connection_request_uuid,
session_id=session_id,
source_ip=source_ip,
)
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
# Requesting a command to be executed
self.sys_log.info("Received command to execute")
command = payload.ssh_command
valid_connection = self._check_client_connection(payload.connection_uuid)
self.sys_log.info(f"Connection uuid is {valid_connection}")
if valid_connection:
return self.execute(command, payload.connection_uuid)
else:
self.sys_log.error(f"Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command.")
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "disconnect":
connection_id = payload["connection_id"]
valid_id = self._check_client_connection(connection_id)
if valid_id:
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from remote.")
self._disconnect(payload["connection_id"])
else:
self.sys_log.info("No Active connection held for received connection ID.")
return True
def _disconnect(self, connection_uuid: str) -> bool:
"""Disconnect from the remote.
:param connection_uuid: Connection ID that we want to disconnect.
:return True if successful, False otherwise.
"""
if not self._connections:
self.sys_log.warning("No remote connection present")
return False
connection = self._connections.pop(connection_uuid)
connection.is_active = False
if isinstance(connection, RemoteTerminalConnection):
# Send disconnect command via software manager
session_id = connection.session_id
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect", "connection_id": connection_uuid},
dest_port=self.port,
session_id=session_id,
)
self.sys_log.info(f"{self.name}: Disconnected {connection_uuid}")
return True
elif isinstance(connection, LocalTerminalConnection):
# No further action needed
return True
def send(
self, payload: SSHPacket, dest_ip_address: Optional[IPv4Address] = None, session_id: Optional[str] = None
) -> bool:
"""
Send a payload out from the Terminal.
:param payload: The payload to be sent.
:param dest_up_address: The IP address of the payload destination.
"""
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.warning(f"Cannot send commands when Operating state is {self.operating_state}!")
return False
self.sys_log.debug(f"Sending payload: {payload}")
return super().send(
payload=payload, dest_ip_address=dest_ip_address, dest_port=self.port, session_id=session_id
)

View File

@@ -291,7 +291,7 @@ class IOSoftware(Software):
"""
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
self.software_manager.node.sys_log.error(
f"{self.name} Error: {self.software_manager.node.hostname} is not online."
f"{self.name} Error: {self.software_manager.node.hostname} is not powered on."
)
return False
return True
@@ -313,7 +313,7 @@ class IOSoftware(Software):
# if over or at capacity, set to overwhelmed
if len(self._connections) >= self.max_sessions:
self.set_health_state(SoftwareHealthState.OVERWHELMED)
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
self.sys_log.warning(f"{self.name}: Connection request ({connection_id}) declined. Service is at capacity.")
return False
else:
# if service was previously overwhelmed, set to good because there is enough space for connections
@@ -330,11 +330,11 @@ class IOSoftware(Software):
"ip_address": session_details.with_ip_address if session_details else None,
"time": datetime.now(),
}
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
self.sys_log.info(f"{self.name}: Connection request ({connection_id}) authorised")
return True
# connection with given id already exists
self.sys_log.warning(
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
f"{self.name}: Connection request ({connection_id}) declined. Connection already exists."
)
return False

View File

@@ -99,7 +99,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -0,0 +1,34 @@
io_settings:
save_step_metadata: false
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
agent_log_level: INFO
save_agent_logs: true
write_agent_log_to_terminal: True
game:
max_episode_length: 256
ports:
- ARP
protocols:
- ICMP
- UDP
simulation:
network:
nodes:
- hostname: client_1
type: computer
ip_address: 192.168.10.11
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
users:
- username: jane.doe
password: '1234'
is_admin: true
- username: john.doe
password: password_1
is_admin: false

View File

@@ -92,7 +92,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -111,7 +111,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -68,7 +68,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -44,7 +44,7 @@ agents:
num_files: 1
num_nics: 1
include_num_access: false
include_nmne: true
include_nmne: false
- type: LINKS
label: LINKS

View File

@@ -89,7 +89,7 @@ agents:
- NONE
tcp:
- DNS
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -120,7 +120,7 @@ agents:
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
include_nmne: false
routers:
- hostname: router_1
num_ports: 0

View File

@@ -30,21 +30,21 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
rayinit(local_mode=True)
rayinit()
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
class TestService(Service):
class DummyService(Service):
"""Test Service class"""
def describe_state(self) -> Dict:
return super().describe_state()
def __init__(self, **kwargs):
kwargs["name"] = "TestService"
kwargs["name"] = "DummyService"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,15 +75,15 @@ def uc2_network() -> Network:
@pytest.fixture(scope="function")
def service(file_system) -> TestService:
return TestService(
name="TestService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service")
def service(file_system) -> DummyService:
return DummyService(
name="DummyService", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_service")
)
@pytest.fixture(scope="function")
def service_class():
return TestService
return DummyService
@pytest.fixture(scope="function")

View File

@@ -1,6 +1,7 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict
import pytest
import yaml
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
@@ -100,6 +101,7 @@ def test_ray_single_agent_action_masking(monkeypatch):
monkeypatch.undo()
@pytest.mark.xfail(reason="Fails due to being flaky when run in CI.")
def test_ray_multi_agent_action_masking(monkeypatch):
"""Check that Ray agents never take invalid actions when using MARL."""
with open(MARL_PATH, "r") as f:

View File

@@ -22,8 +22,7 @@ def test_passing_actions_down(monkeypatch) -> None:
for n in [pc1, pc2, srv, s1]:
sim.network.add_node(n)
database_service = DatabaseService(file_system=srv.file_system)
srv.install_service(database_service)
srv.software_manager.install(DatabaseService)
downloads_folder = pc1.file_system.create_folder("downloads")
pc1.file_system.create_file("bermuda_triangle.png", folder_name="downloads")

View File

@@ -45,7 +45,7 @@ def test_fix_duration_set_from_config():
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
for service in SERVICE_TYPES_MAPPING:
for service in ["DNSClient", "DNSServer", "DatabaseService", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 3
@@ -53,7 +53,7 @@ def test_fix_duration_set_from_config():
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
for application in applications:
for application in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot", "DatabaseClient"]:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
@@ -64,17 +64,13 @@ def test_fix_duration_for_one_item():
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
# in config - services take 3 timesteps to fix
services = copy.copy(SERVICE_TYPES_MAPPING)
services.pop("DatabaseService")
for service in services:
for service in ["DNSClient", "DNSServer", "WebServer", "FTPClient", "FTPServer", "NTPServer"]:
assert client_1.software_manager.software.get(service) is not None
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications.remove("DatabaseClient")
for applications in applications:
for applications in ["RansomwareScript", "WebBrowser", "DataManipulationBot", "DoSBot"]:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2

View File

@@ -9,9 +9,11 @@ from gymnasium import spaces
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import NetworkInterface
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
@@ -75,6 +77,18 @@ def test_nic(simulation):
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
nmne_config = {
"capture_nmne": True, # Enable the capture of MNEs
"nmne_capture_keywords": [
"DELETE",
"ENCRYPT",
], # Specify "DELETE/ENCRYPT" SQL command as a keyword for MNE detection
}
# Apply the NMNE configuration settings
NetworkInterface.nmne_config = NMNEConfig(**nmne_config)
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
@@ -144,7 +158,7 @@ def test_nic_monitored_traffic(simulation):
pc2: Computer = simulation.network.get_node_by_hostname("client_2")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True, monitored_traffic=monitored_traffic
where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=False, monitored_traffic=monitored_traffic
)
simulation.pre_timestep(0) # apply timestep to whole sim

View File

@@ -0,0 +1,50 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from pprint import pprint
import pytest
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import AgentHistoryItem
from primaite.session.environment import PrimaiteGymEnv
@pytest.fixture()
def create_env():
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(env_config=cfg)
return env
def test_rng_seed_set(create_env):
"""Test with RNG seed set."""
env = create_env
env.reset(seed=3)
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset(seed=3)
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a == b
def test_rng_seed_unset(create_env):
"""Test with no RNG seed."""
env = create_env
env.reset()
for i in range(100):
env.step(0)
a = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
env.reset()
for i in range(100):
env.step(0)
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "DONOTHING"]
assert a != b

View File

@@ -1,12 +1,14 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.network.nmne import NMNEConfig
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
def test_capture_nmne(uc2_network):
def test_capture_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured.
@@ -33,7 +35,7 @@ def test_capture_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
assert web_server_nic.nmne == {}
@@ -82,7 +84,7 @@ def test_capture_nmne(uc2_network):
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
def test_describe_state_nmne(uc2_network):
def test_describe_state_nmne(uc2_network: Network):
"""
Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state.
@@ -110,7 +112,7 @@ def test_describe_state_nmne(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Assert that initially, there are no captured MNEs on both web and database servers
web_server_nic_state = web_server_nic.describe_state()
@@ -190,7 +192,7 @@ def test_describe_state_nmne(uc2_network):
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 4}}}}
def test_capture_nmne_observations(uc2_network):
def test_capture_nmne_observations(uc2_network: Network):
"""
Tests the NICObservation class's functionality within a simulated network environment.
@@ -219,7 +221,7 @@ def test_capture_nmne_observations(uc2_network):
}
# Apply the NMNE configuration settings
set_nmne_config(nmne_config)
NIC.nmne_config = NMNEConfig(**nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)

View File

@@ -0,0 +1,26 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.base import UserManager
from tests import TEST_ASSETS_ROOT
def test_users_from_config():
config_path = TEST_ASSETS_ROOT / "configs" / "basic_node_with_users.yaml"
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
client_1 = network.get_node_by_hostname("client_1")
user_manager: UserManager = client_1.software_manager.software["UserManager"]
assert len(user_manager.users) == 3
assert user_manager.users["jane.doe"].password == "1234"
assert user_manager.users["jane.doe"].is_admin
assert user_manager.users["john.doe"].password == "password_1"
assert not user_manager.users["john.doe"].is_admin

View File

@@ -23,7 +23,7 @@ def populated_node(
server.power_on()
server.software_manager.install(service_class)
service = server.software_manager.software.get("TestService")
service = server.software_manager.software.get("DummyService")
service.start()
return server, service
@@ -42,7 +42,7 @@ def test_service_on_offline_node(service_class):
computer.power_on()
computer.software_manager.install(service_class)
service: Service = computer.software_manager.software.get("TestService")
service: Service = computer.software_manager.software.get("DummyService")
computer.power_off()

View File

@@ -0,0 +1,274 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
from uuid import uuid4
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import User
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
@pytest.fixture(scope="function")
def client_server_network() -> Tuple[Computer, Server, Network]:
network = Network()
client = Computer(
hostname="client",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
client.power_on()
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
network.connect(client.network_interface[1], server.network_interface[1])
return client, server, network
def test_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
def test_login_count_increases(client_server_network):
client, server, network = client_server_network
admin_user: User = client.user_manager.users["admin"]
assert admin_user.num_of_logins == 0
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 1
client.user_session_manager.local_login(username="admin", password="admin")
# shouldn't change as user is already logged in
assert admin_user.num_of_logins == 1
client.user_session_manager.local_logout()
client.user_session_manager.local_login(username="admin", password="admin")
assert admin_user.num_of_logins == 2
def test_local_login_failure(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert not client.user_session_manager.local_user_logged_in
def test_new_user_local_login_success(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
client.user_manager.add_user(username="jane.doe", password="12345")
client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
def test_new_local_login_clears_previous_login(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
client.user_manager.add_user(username="jane.doe", password="12345")
new_session_id = client.user_session_manager.local_login(username="jane.doe", password="12345")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "jane.doe"
assert new_session_id != current_session_id
def test_new_local_login_attempt_same_uses_persists(client_server_network):
client, server, network = client_server_network
assert not client.user_session_manager.local_user_logged_in
current_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
new_session_id = client.user_session_manager.local_login(username="admin", password="admin")
assert client.user_session_manager.local_user_logged_in
assert client.user_session_manager.local_session.user.username == "admin"
assert new_session_id == current_session_id
def test_remote_login_success(client_server_network):
# partial test for now until we get the terminal application in so that amn actual remote connection can be made
client, server, network = client_server_network
assert not server.user_session_manager.remote_sessions
remote_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
assert server.user_session_manager.validate_remote_session_uuid(remote_session_id)
server.user_session_manager.remote_logout(remote_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_remote_login_failure(client_server_network):
# partial test for now until we get the terminal application in so that amn actual remote connection can be made
client, server, network = client_server_network
assert not server.user_session_manager.remote_sessions
remote_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_new_user_remote_login_success(client_server_network):
client, server, network = client_server_network
server.user_manager.add_user(username="jane.doe", password="12345")
remote_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
assert server.user_session_manager.validate_remote_session_uuid(remote_session_id)
server.user_session_manager.remote_logout(remote_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(remote_session_id)
def test_max_remote_sessions_same_user(client_server_network):
client, server, network = client_server_network
remote_session_ids = [
server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10")
for _ in range(server.user_session_manager.max_remote_sessions)
]
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_max_remote_sessions_different_users(client_server_network):
client, server, network = client_server_network
remote_session_ids = []
for i in range(server.user_session_manager.max_remote_sessions):
username = str(uuid4())
password = "12345"
server.user_manager.add_user(username=username, password=password)
remote_session_ids.append(
server.user_session_manager.remote_login(
username=username, password=password, remote_ip_address="192.168.1.10"
)
)
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_max_remote_sessions_limit_reached(client_server_network):
client, server, network = client_server_network
remote_session_ids = [
server.user_session_manager.remote_login(username="admin", password="admin", remote_ip_address="192.168.1.10")
for _ in range(server.user_session_manager.max_remote_sessions)
]
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
assert len(server.user_session_manager.remote_sessions) == server.user_session_manager.max_remote_sessions
fourth_attempt_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
assert not server.user_session_manager.validate_remote_session_uuid(fourth_attempt_session_id)
assert all([server.user_session_manager.validate_remote_session_uuid(id) for id in remote_session_ids])
def test_single_remote_logout_others_persist(client_server_network):
client, server, network = client_server_network
server.user_manager.add_user(username="jane.doe", password="12345")
server.user_manager.add_user(username="john.doe", password="12345")
admin_session_id = server.user_session_manager.remote_login(
username="admin", password="admin", remote_ip_address="192.168.1.10"
)
jane_session_id = server.user_session_manager.remote_login(
username="jane.doe", password="12345", remote_ip_address="192.168.1.10"
)
john_session_id = server.user_session_manager.remote_login(
username="john.doe", password="12345", remote_ip_address="192.168.1.10"
)
server.user_session_manager.remote_logout(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert server.user_session_manager.validate_remote_session_uuid(john_session_id)
server.user_session_manager.remote_logout(jane_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert server.user_session_manager.validate_remote_session_uuid(john_session_id)
server.user_session_manager.remote_logout(john_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(admin_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(jane_session_id)
assert not server.user_session_manager.validate_remote_session_uuid(john_session_id)

View File

@@ -13,7 +13,7 @@ from primaite.simulator.network.hardware.node_operating_state import NodeOperati
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from tests.conftest import DummyApplication, TestService
from tests.conftest import DummyApplication, DummyService
def test_successful_node_file_system_creation_request(example_network):
@@ -61,7 +61,7 @@ def test_successful_application_requests(example_network):
def test_successful_service_requests(example_network):
net = example_network
server_1 = net.get_node_by_hostname("server_1")
server_1.software_manager.install(TestService)
server_1.software_manager.install(DummyService)
# Careful: the order here is important, for example we cannot run "stop" unless we run "start" first
for verb in [
@@ -77,7 +77,7 @@ def test_successful_service_requests(example_network):
"scan",
"fix",
]:
resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb])
resp_1 = net.apply_request(["node", "server_1", "service", "DummyService", verb])
assert resp_1 == RequestResponse(status="success", data={})
server_1.apply_timestep(timestep=1)
server_1.apply_timestep(timestep=1)

View File

@@ -62,7 +62,6 @@ def test_probabilistic_agent():
reward_function=reward_function,
settings={
"action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE},
"random_seed": 120,
},
)

View File

@@ -7,6 +7,7 @@ from primaite.simulator.file_system.folder import Folder
from primaite.simulator.network.hardware.base import Node, NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.software import SoftwareHealthState
from tests.conftest import DummyApplication, DummyService
@pytest.fixture
@@ -47,7 +48,7 @@ def test_node_shutdown(node):
assert node.operating_state == NodeOperatingState.OFF
def test_node_os_scan(node, service, application):
def test_node_os_scan(node):
"""Test OS Scanning."""
node.operating_state = NodeOperatingState.ON
@@ -55,13 +56,15 @@ def test_node_os_scan(node, service, application):
# TODO implement processes
# add services to node
node.software_manager.install(DummyService)
service = node.software_manager.software.get("DummyService")
service.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_service(service=service)
assert service.health_state_visible == SoftwareHealthState.UNUSED
# add application to node
node.software_manager.install(DummyApplication)
application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.health_state_visible == SoftwareHealthState.UNUSED
# add folder and file to node
@@ -91,7 +94,7 @@ def test_node_os_scan(node, service, application):
assert file2.visible_health_status == FileSystemItemHealthStatus.CORRUPT
def test_node_red_scan(node, service, application):
def test_node_red_scan(node):
"""Test revealing to red"""
node.operating_state = NodeOperatingState.ON
@@ -99,12 +102,14 @@ def test_node_red_scan(node, service, application):
# TODO implement processes
# add services to node
node.install_service(service=service)
node.software_manager.install(DummyService)
service = node.software_manager.software.get("DummyService")
assert service.revealed_to_red is False
# add application to node
node.software_manager.install(DummyApplication)
application = node.software_manager.software.get("DummyApplication")
application.set_health_state(SoftwareHealthState.COMPROMISED)
node.install_application(application=application)
assert application.revealed_to_red is False
# add folder and file to node

View File

@@ -0,0 +1,380 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
from uuid import uuid4
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.protocols.ssh import (
SSHConnectionMessage,
SSHPacket,
SSHTransportMessage,
SSHUserCredentials,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.red_applications.ransomware_script import RansomwareScript
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.terminal.terminal import RemoteTerminalConnection, Terminal
from primaite.simulator.system.services.web_server.web_server import WebServer
@pytest.fixture(scope="function")
def terminal_on_computer() -> Tuple[Terminal, Computer]:
computer: Computer = Computer(
hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0
)
computer.power_on()
terminal: Terminal = computer.software_manager.software.get("Terminal")
return terminal, computer
@pytest.fixture(scope="function")
def basic_network() -> Network:
network = Network()
node_a = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
node_a.power_on()
node_a.software_manager.get_open_ports()
node_b = Computer(hostname="node_b", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0)
node_b.power_on()
network.connect(node_a.network_interface[1], node_b.network_interface[1])
return network
@pytest.fixture(scope="function")
def wireless_wan_network():
network = Network()
# Configure PC A
pc_a = Computer(
hostname="pc_a",
ip_address="192.168.0.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.0.1",
start_up_duration=0,
)
pc_a.power_on()
network.add_node(pc_a)
# Configure Router 1
router_1 = WirelessRouter(hostname="router_1", start_up_duration=0, airspace=network.airspace)
router_1.power_on()
network.add_node(router_1)
# Configure the connection between PC A and Router 1 port 2
router_1.configure_router_interface("192.168.0.1", "255.255.255.0")
network.connect(pc_a.network_interface[1], router_1.network_interface[2])
# Configure Router 1 ACLs
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
# add ACL rule to allow SSH traffic
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21)
# Configure PC B
pc_b = Computer(
hostname="pc_b",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
start_up_duration=0,
)
pc_b.power_on()
network.add_node(pc_b)
# Configure Router 2
router_2 = WirelessRouter(hostname="router_2", start_up_duration=0, airspace=network.airspace)
router_2.power_on()
network.add_node(router_2)
# Configure the connection between PC B and Router 2 port 2
router_2.configure_router_interface("192.168.2.1", "255.255.255.0")
network.connect(pc_b.network_interface[1], router_2.network_interface[2])
# Configure Router 2 ACLs
# Configure the wireless connection between Router 1 port 1 and Router 2 port 1
router_1.configure_wireless_access_point("192.168.1.1", "255.255.255.0")
router_2.configure_wireless_access_point("192.168.1.2", "255.255.255.0")
router_1.route_table.add_route(
address="192.168.2.0", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.2"
)
# Configure Route from Router 2 to PC A subnet
router_2.route_table.add_route(
address="192.168.0.2", subnet_mask="255.255.255.0", next_hop_ip_address="192.168.1.1"
)
return pc_a, pc_b, router_1, router_2
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return game, agent
def test_terminal_creation(terminal_on_computer):
terminal, computer = terminal_on_computer
terminal.describe_state()
def test_terminal_install_default():
"""Terminal should be auto installed onto Nodes"""
computer = Computer(hostname="node_a", ip_address="192.168.0.10", subnet_mask="255.255.255.0", start_up_duration=0)
computer.power_on()
assert computer.software_manager.software.get("Terminal")
def test_terminal_not_on_switch():
"""Ensure terminal does not auto-install to switch"""
test_switch = Switch(hostname="Test")
assert not test_switch.software_manager.software.get("Terminal")
def test_terminal_send(basic_network):
"""Test that Terminal can send valid commands."""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
payload: SSHPacket = SSHPacket(
payload="Test_Payload",
transport_message=SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
user_account=SSHUserCredentials(username="username", password="password"),
connection_request_uuid=str(uuid4()),
)
assert terminal_a.send(payload=payload, dest_ip_address=computer_b.network_interface[1].ip_address)
def test_terminal_receive(basic_network):
"""Test that terminal can receive and process commands"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
folder_name = "Downloads"
payload: SSHPacket = SSHPacket(
payload=["file_system", "create", "folder", folder_name],
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["file_system", "create", "folder", folder_name])
# Assert that the Folder has been correctly created
assert computer_b.file_system.get_folder(folder_name)
def test_terminal_install(basic_network):
"""Test that Terminal can successfully process an INSTALL request"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
payload: SSHPacket = SSHPacket(
payload=["software_manager", "application", "install", "RansomwareScript"],
transport_message=SSHTransportMessage.SSH_MSG_SERVICE_REQUEST,
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_OPEN,
)
term_a_on_node_b: RemoteTerminalConnection = terminal_a.login(
username="username", password="password", ip_address="192.168.0.11"
)
term_a_on_node_b.execute(["software_manager", "application", "install", "RansomwareScript"])
assert computer_b.software_manager.software.get("RansomwareScript")
def test_terminal_fail_when_closed(basic_network):
"""Ensure Terminal won't attempt to send/receive when off"""
network: Network = basic_network
computer: Computer = network.get_node_by_hostname("node_a")
terminal: Terminal = computer.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal.operating_state = ServiceOperatingState.STOPPED
assert not terminal.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
)
def test_terminal_disconnect(basic_network):
"""Test Terminal disconnects"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
assert len(terminal_b._connections) == 0
term_a_on_term_b = terminal_a.login(
username="admin", password="Admin123!", ip_address=computer_b.network_interface[1].ip_address
)
assert len(terminal_b._connections) == 1
term_a_on_term_b.disconnect()
assert len(terminal_b._connections) == 0
def test_terminal_ignores_when_off(basic_network):
"""Terminal should ignore commands when not running"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
term_a_on_term_b: RemoteTerminalConnection = terminal_a.login(
username="admin", password="Admin123!", ip_address="192.168.0.11"
) # login to computer_b
terminal_a.operating_state = ServiceOperatingState.STOPPED
assert not term_a_on_term_b.execute(["software_manager", "application", "install", "RansomwareScript"])
def test_computer_remote_login_to_router(wireless_wan_network):
"""Test to confirm that a computer can SSH into a router."""
pc_a, _, router_1, _ = wireless_wan_network
pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal")
assert len(pc_a_terminal._connections) == 0
pc_a_on_router_1 = pc_a_terminal.login(username="username", password="password", ip_address="192.168.1.1")
assert len(pc_a_terminal._connections) == 1
payload = ["software_manager", "application", "install", "RansomwareScript"]
pc_a_on_router_1.execute(payload)
assert router_1.software_manager.software.get("RansomwareScript")
def test_router_remote_login_to_computer(wireless_wan_network):
"""Test to confirm that a router can ssh into a computer."""
pc_a, _, router_1, _ = wireless_wan_network
router_1_terminal: Terminal = router_1.software_manager.software.get("Terminal")
assert len(router_1_terminal._connections) == 0
router_1_on_pc_a = router_1_terminal.login(username="username", password="password", ip_address="192.168.0.2")
assert len(router_1_terminal._connections) == 1
payload = ["software_manager", "application", "install", "RansomwareScript"]
router_1_on_pc_a.execute(payload)
assert pc_a.software_manager.software.get("RansomwareScript")
def test_router_blocks_SSH_traffic(wireless_wan_network):
"""Test to check that router will block SSH traffic if no ACL rule."""
pc_a, _, router_1, _ = wireless_wan_network
# Remove rule that allows SSH traffic.
router_1.acl.remove_rule(position=21)
pc_a_terminal: Terminal = pc_a.software_manager.software.get("Terminal")
assert len(pc_a_terminal._connections) == 0
pc_a_terminal.login(username="username", password="password", ip_address="192.168.0.2")
assert len(pc_a_terminal._connections) == 0
def test_SSH_across_network(wireless_wan_network):
"""Test to show ability to SSH across a network."""
pc_a, pc_b, router_1, router_2 = wireless_wan_network
terminal_a: Terminal = pc_a.software_manager.software.get("Terminal")
terminal_b: Terminal = pc_b.software_manager.software.get("Terminal")
router_2.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.SSH, dst_port=Port.SSH, position=21)
assert len(terminal_a._connections) == 0
terminal_b_on_terminal_a = terminal_b.login(username="username", password="password", ip_address="192.168.0.2")
assert len(terminal_a._connections) == 1
def test_multiple_remote_terminals_same_node(basic_network):
"""Test to check that multiple remote terminals can be spawned by one node."""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
assert len(terminal_a._connections) == 0
# Spam login requests to terminal.
for attempt in range(10):
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 10
def test_terminal_rejects_commands_if_disconnect(basic_network):
"""Test to check terminal will ignore commands from disconnected connections"""
network: Network = basic_network
computer_a: Computer = network.get_node_by_hostname("node_a")
terminal_a: Terminal = computer_a.software_manager.software.get("Terminal")
computer_b: Computer = network.get_node_by_hostname("node_b")
terminal_b: Terminal = computer_b.software_manager.software.get("Terminal")
remote_connection = terminal_a.login(username="username", password="password", ip_address="192.168.0.11")
assert len(terminal_a._connections) == 1
assert len(terminal_b._connections) == 1
remote_connection.disconnect()
assert len(terminal_a._connections) == 0
assert len(terminal_b._connections) == 0
assert remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"]) is False
assert not computer_b.software_manager.software.get("RansomwareScript")
assert remote_connection.is_active is False