Merge remote-tracking branch 'origin/dev' into 4.0.0-dev
This commit is contained in:
10
CHANGELOG.md
10
CHANGELOG.md
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
## [4.0.0] = TBC
|
||||
|
||||
### Added
|
||||
- Log observation space data by episode and step.
|
||||
- Added `show_history` method to Agents, allowing you to view actions taken by an agent per step. By default, `do-nothing` actions are omitted.
|
||||
- New ``node-send-local-command`` action implemented which grants agents the ability to execute commands locally. (Previously limited to remote only)
|
||||
- Added ability to set the observation threshold for NMNE, file access and application executions
|
||||
|
||||
### Changed
|
||||
- Agents now follow a common configuration format, simplifying the configuration of agents and their extensibilty.
|
||||
@@ -24,6 +28,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Updated tests that don't use YAMLs to still use the new action and agent schemas
|
||||
- Nodes now use a config schema and are extensible, allowing for plugin support.
|
||||
- Node tests have been updated to use the new node config schemas when not using YAML files.
|
||||
- ACLs are no longer applied to layer-2 traffic.
|
||||
- Random number seed values are recorded in simulation/seed.log if the seed is set in the config file
|
||||
or `generate_seed_value` is set to `true`.
|
||||
- ARP .show() method will now include the port number associated with each entry.
|
||||
- Added `services_requires_scan` and `applications_requires_scan` to agent observation space config to allow the agents to be able to see actual health states of services and applications without requiring scans (Default `True`, set to `False` to allow agents to see actual health state without scanning).
|
||||
- Updated the `Terminal` class to provide response information when sending remote command execution.
|
||||
|
||||
### Fixed
|
||||
- DNS client no longer fails to check its cache if a DNS server address is missing.
|
||||
|
||||
@@ -21,7 +21,7 @@ Agents can be scripted (deterministic and stochastic), or controlled by a reinfo
|
||||
team: GREEN
|
||||
type: probabilistic-agent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
type: UC2GreenObservation # TODO: what
|
||||
action_space:
|
||||
reward_function:
|
||||
reward_components:
|
||||
@@ -160,3 +160,4 @@ If ``True``, gymnasium flattening will be performed on the observation space bef
|
||||
-----------------
|
||||
|
||||
Agents will record their action log for each step. This is a summary of what the agent did, along with response information from requests within the simulation.
|
||||
A summary of the actions taken by the agent can be viewed using the `show_history()` function. By default, this will display all actions taken apart from ``DONOTHING``.
|
||||
|
||||
@@ -54,6 +54,39 @@ Optional. Default value is ``3``.
|
||||
|
||||
The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``.
|
||||
|
||||
``file_system``
|
||||
---------------
|
||||
|
||||
Optional.
|
||||
|
||||
The file system of the node. This configuration allows nodes to be initialised with files and/or folders.
|
||||
|
||||
The file system takes a list of folders and files.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.11
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
file_system:
|
||||
- empty_folder # example of an empty folder
|
||||
- downloads:
|
||||
- "test_1.txt" # files in the downloads folder
|
||||
- "test_2.txt"
|
||||
- root:
|
||||
- passwords: # example of file with size and type
|
||||
size: 69 # size in bytes
|
||||
type: TXT # See FileType for list of available file types
|
||||
|
||||
List of file types: :py:mod:`primaite.simulator.file_system.file_type.FileType`
|
||||
|
||||
``users``
|
||||
---------
|
||||
|
||||
|
||||
@@ -1177,8 +1177,8 @@ ACLs permitting or denying traffic as per our configured ACL rules.
|
||||
some_tech_storage_srv = network.get_node_by_hostname("some_tech_storage_srv")
|
||||
some_tech_storage_srv.file_system.create_file(file_name="test.png")
|
||||
|
||||
pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["FTPClient"]
|
||||
pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["FTPClient"]
|
||||
pc_1_ftp_client: FTPClient = network.get_node_by_hostname("pc_1").software_manager.software["ftp-client"]
|
||||
pc_2_ftp_client: FTPClient = network.get_node_by_hostname("pc_2").software_manager.software["ftp-client"]
|
||||
|
||||
assert not pc_1_ftp_client.request_file(
|
||||
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
|
||||
@@ -1224,7 +1224,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
|
||||
|
||||
web_server: Server = network.get_node_by_hostname("some_tech_web_srv")
|
||||
|
||||
web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"]
|
||||
web_ftp_client: FTPClient = web_server.software_manager.software["ftp-client"]
|
||||
|
||||
assert not web_ftp_client.request_file(
|
||||
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
|
||||
@@ -1269,7 +1269,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
|
||||
some_tech_storage_srv.file_system.create_file(file_name="test.png")
|
||||
|
||||
some_tech_snr_dev_pc: Computer = network.get_node_by_hostname("some_tech_snr_dev_pc")
|
||||
snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["FTPClient"]
|
||||
snr_dev_ftp_client: FTPClient = some_tech_snr_dev_pc.software_manager.software["ftp-client"]
|
||||
|
||||
assert snr_dev_ftp_client.request_file(
|
||||
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
|
||||
@@ -1294,7 +1294,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
|
||||
some_tech_storage_srv.file_system.create_file(file_name="test.png")
|
||||
|
||||
some_tech_jnr_dev_pc: Computer = network.get_node_by_hostname("some_tech_jnr_dev_pc")
|
||||
jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["FTPClient"]
|
||||
jnr_dev_ftp_client: FTPClient = some_tech_jnr_dev_pc.software_manager.software["ftp-client"]
|
||||
|
||||
assert not jnr_dev_ftp_client.request_file(
|
||||
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
|
||||
@@ -1337,7 +1337,7 @@ ACLs permitting or denying traffic as per our configured ACL rules.
|
||||
some_tech_storage_srv.file_system.create_file(file_name="test.png")
|
||||
|
||||
some_tech_hr_pc: Computer = network.get_node_by_hostname("some_tech_hr_1")
|
||||
hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["FTPClient"]
|
||||
hr_ftp_client: FTPClient = some_tech_hr_pc.software_manager.software["ftp-client"]
|
||||
|
||||
assert not hr_ftp_client.request_file(
|
||||
dest_ip_address=some_tech_storage_srv.network_interface[1].ip_address,
|
||||
|
||||
@@ -74,7 +74,7 @@ The subnet mask setting for the port.
|
||||
``acl``
|
||||
-------
|
||||
|
||||
Sets up the ACL rules for the router.
|
||||
Sets up the ACL rules for the router to apply to layer-3 traffic. These are not applied to layer-2 traffic such as ARP.
|
||||
|
||||
e.g.
|
||||
|
||||
@@ -85,10 +85,6 @@ e.g.
|
||||
...
|
||||
acl:
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
2:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
|
||||
@@ -46,17 +46,13 @@ The core features that should be implemented in any new agent are detailed below
|
||||
|
||||
- ref: example_green_agent
|
||||
team: GREEN
|
||||
type: ExampleAgent
|
||||
type: example-agent
|
||||
|
||||
action_space:
|
||||
action_map:
|
||||
0:
|
||||
action: do-nothing
|
||||
options: {}
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: dummy
|
||||
|
||||
agent_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
|
||||
@@ -26,9 +26,9 @@ class Router(NetworkNode, identifier="router"):
|
||||
""" Represents a network router within the simulation, managing routing and forwarding of IP packets across network interfaces."""
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
"user-session-manager": UserSessionManager,
|
||||
"user-manager": UserManager,
|
||||
"terminal": Terminal,
|
||||
}
|
||||
|
||||
network_interfaces: Dict[str, RouterInterface] = {}
|
||||
@@ -52,4 +52,4 @@ class Router(NetworkNode, identifier="router"):
|
||||
Changes to YAML file.
|
||||
=====================
|
||||
|
||||
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.
|
||||
While effort has been made to ensure that nodes defined within configuration YAML files for use with PrimAITE 3.X remain compatible with PrimAITE v4+, it is encouraged to review for minor changes needed.
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
© Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _request_system:
|
||||
|
||||
Request System
|
||||
**************
|
||||
|
||||
|
||||
@@ -97,19 +97,19 @@ we'll use the following Network that has a client, server, two switches, and a r
|
||||
network.connect(endpoint_a=switch_2.network_interface[1], endpoint_b=client_1.network_interface[1])
|
||||
network.connect(endpoint_a=switch_1.network_interface[1], endpoint_b=server_1.network_interface[1])
|
||||
|
||||
8. Add ACL rules on the Router to allow ARP and ICMP traffic.
|
||||
8. Add an ACL rule on the Router to allow ICMP traffic.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT,
|
||||
src_port=Port["ARP"],
|
||||
dst_port=Port["ARP"],
|
||||
src_port=PORT_LOOKUP["ARP"],
|
||||
dst_port=PORT_LOOKUP["ARP"],
|
||||
position=22
|
||||
)
|
||||
|
||||
router_1.acl.add_rule(
|
||||
action=ACLAction.PERMIT,
|
||||
protocol=IPProtocol["ICMP"],
|
||||
protocol=PROTOCOL_LOOKUP["ICMP"],
|
||||
position=23
|
||||
)
|
||||
|
||||
@@ -102,8 +102,8 @@ ICMP traffic, ensuring basic network connectivity and ping functionality.
|
||||
network.connect(pc_a.network_interface[1], router_1.router_interface)
|
||||
|
||||
# Configure Router 1 ACLs
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port["ARP"], dst_port=Port["ARP"], position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol["ICMP"], position=23)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=22)
|
||||
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=PROTOCOL_LOOKUP["ICMP"], position=23)
|
||||
|
||||
# Configure PC B
|
||||
pc_b = Computer(
|
||||
|
||||
@@ -183,7 +183,7 @@ Python
|
||||
# Example command: Installing and configuring Ransomware:
|
||||
|
||||
ransomware_installation_command = { "commands": [
|
||||
["software_manager","application","install","RansomwareScript"],
|
||||
["software_manager","application","install","ransomware-script"],
|
||||
],
|
||||
"username": "admin",
|
||||
"password": "admin",
|
||||
|
||||
@@ -77,7 +77,7 @@ Python
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
client_1.software_manager.install(DataManipulationBot)
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
|
||||
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot")
|
||||
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
|
||||
data_manipulation_bot.run()
|
||||
|
||||
@@ -98,7 +98,7 @@ If not using the data manipulation bot manually, it needs to be used with a data
|
||||
type: red-database-corrupting-agent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
type: uc2-red-observation #TODO what
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
|
||||
@@ -59,7 +59,7 @@ Python
|
||||
# install DatabaseClient
|
||||
client.software_manager.install(DatabaseClient)
|
||||
|
||||
database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient")
|
||||
database_client: DatabaseClient = client.software_manager.software.get("database-sclient")
|
||||
|
||||
# Configure the DatabaseClient
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService
|
||||
|
||||
@@ -62,7 +62,7 @@ Python
|
||||
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
|
||||
client_1.software_manager.install(DatabaseClient)
|
||||
client_1.software_manager.install(RansomwareScript)
|
||||
RansomwareScript: RansomwareScript = client_1.software_manager.software.get("RansomwareScript")
|
||||
RansomwareScript: RansomwareScript = client_1.software_manager.software.get("ransomware-script")
|
||||
RansomwareScript.configure(server_ip_address=IPv4Address("192.168.1.14"))
|
||||
RansomwareScript.execute()
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`D
|
||||
|
||||
# Install WebBrowser on computer
|
||||
computer.software_manager.install(WebBrowser)
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
|
||||
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
|
||||
web_browser.run()
|
||||
|
||||
# configure the WebBrowser
|
||||
|
||||
@@ -66,7 +66,7 @@ Python
|
||||
|
||||
# Install DatabaseService on server
|
||||
server.software_manager.install(DatabaseService)
|
||||
db_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
db_service: DatabaseService = server.software_manager.software.get("database-service")
|
||||
db_service.start()
|
||||
|
||||
# configure DatabaseService
|
||||
|
||||
@@ -56,7 +56,7 @@ Python
|
||||
|
||||
# Install DNSClient on server
|
||||
server.software_manager.install(DNSClient)
|
||||
dns_client: DNSClient = server.software_manager.software.get("DNSClient")
|
||||
dns_client: DNSClient = server.software_manager.software.get("dns-client")
|
||||
dns_client.start()
|
||||
|
||||
# configure DatabaseService
|
||||
|
||||
@@ -53,7 +53,7 @@ Python
|
||||
|
||||
# Install DNSServer on server
|
||||
server.software_manager.install(DNSServer)
|
||||
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
|
||||
dns_server: DNSServer = server.software_manager.software.get("dns-server")
|
||||
dns_server.start()
|
||||
|
||||
# configure DatabaseService
|
||||
|
||||
@@ -60,7 +60,7 @@ Python
|
||||
|
||||
# Install FTPClient on server
|
||||
server.software_manager.install(FTPClient)
|
||||
ftp_client: FTPClient = server.software_manager.software.get("FTPClient")
|
||||
ftp_client: FTPClient = server.software_manager.software.get("ftp-client")
|
||||
ftp_client.start()
|
||||
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ Python
|
||||
|
||||
# Install FTPServer on server
|
||||
server.software_manager.install(FTPServer)
|
||||
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
|
||||
ftp_server: FTPServer = server.software_manager.software.get("ftp-server")
|
||||
ftp_server.start()
|
||||
|
||||
ftp_server.server_password = "test"
|
||||
|
||||
@@ -53,7 +53,7 @@ Python
|
||||
|
||||
# Install NTPClient on server
|
||||
server.software_manager.install(NTPClient)
|
||||
ntp_client: NTPClient = server.software_manager.software.get("NTPClient")
|
||||
ntp_client: NTPClient = server.software_manager.software.get("ntp-client")
|
||||
ntp_client.start()
|
||||
|
||||
ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10"))
|
||||
|
||||
@@ -55,7 +55,7 @@ Python
|
||||
|
||||
# Install NTPServer on server
|
||||
server.software_manager.install(NTPServer)
|
||||
ntp_server: NTPServer = server.software_manager.software.get("NTPServer")
|
||||
ntp_server: NTPServer = server.software_manager.software.get("ntp-server")
|
||||
ntp_server.start()
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,14 @@ Key capabilities
|
||||
- Simulates common Terminal processes/commands.
|
||||
- Leverages the Service base class for install/uninstall, status tracking etc.
|
||||
|
||||
Usage
|
||||
"""""
|
||||
|
||||
- Pre-Installs on any `Node` component (with the exception of `Switches`).
|
||||
- Terminal Clients connect, execute commands and disconnect from remote nodes.
|
||||
- Ensures that users are logged in to the component before executing any commands.
|
||||
- Service runs on SSH port 22 by default.
|
||||
- Enables Agents to send commands both remotely and locally.
|
||||
|
||||
Implementation
|
||||
""""""""""""""
|
||||
@@ -30,19 +38,112 @@ Implementation
|
||||
- Manages remote connections in a dictionary by session ID.
|
||||
- Processes commands, forwarding to the ``RequestManager`` or ``SessionManager`` where appropriate.
|
||||
- Extends Service class.
|
||||
- A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
|
||||
|
||||
A detailed guide on the implementation and functionality of the Terminal class can be found in the "Terminal-Processing" jupyter notebook.
|
||||
|
||||
Command Format
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Terminals implement their commands through leveraging the pre-existing :ref:`request_system`.
|
||||
|
||||
Due to this Terminals will only accept commands passed within the ``RequestFormat``.
|
||||
|
||||
:py:class:`primaite.game.interface.RequestFormat`
|
||||
|
||||
For example, ``terminal`` command actions when used in ``yaml`` format are formatted as follows:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
command:
|
||||
- "file_system"
|
||||
- "create"
|
||||
- "file"
|
||||
- "downloads"
|
||||
- "cat.png"
|
||||
- "False
|
||||
|
||||
This is then loaded from yaml into a dictionary containing the terminal command:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{"command":["file_system", "create", "file", "downloads", "cat.png", "False"]}
|
||||
|
||||
Which is then passed to the ``Terminals`` Request Manager to be executed.
|
||||
|
||||
Game Layer Usage (Agents)
|
||||
========================
|
||||
|
||||
The below code examples demonstrate how to use terminal related actions in yaml files.
|
||||
|
||||
yaml
|
||||
""""
|
||||
|
||||
``node-send-local-command``
|
||||
"""""""""""""""""""""""""""
|
||||
|
||||
Agents can execute local commands without needing to perform a separate remote login action (``node-session-remote-login``).
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
...
|
||||
...
|
||||
action: node-send-local-command
|
||||
options:
|
||||
node_id: 0
|
||||
username: admin
|
||||
password: admin
|
||||
command: # Example command - Creates a file called 'cat.png' in the downloads folder.
|
||||
- "file_system"
|
||||
- "create"
|
||||
- "file"
|
||||
- "downloads"
|
||||
- "cat.png"
|
||||
- "False"
|
||||
|
||||
|
||||
Usage
|
||||
"""""
|
||||
``node-session-remote-login``
|
||||
"""""""""""""""""
|
||||
|
||||
- Pre-Installs on all ``Nodes`` (with the exception of ``Switches``).
|
||||
- Terminal Clients connect, execute commands and disconnect from remote nodes.
|
||||
- Ensures that users are logged in to the component before executing any commands.
|
||||
- Service runs on SSH port 22 by default.
|
||||
Agents are able to use the terminal to login into remote nodes via ``SSH`` which allows for agents to execute commands on remote hosts.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
...
|
||||
...
|
||||
action: node-session-remote-login
|
||||
options:
|
||||
node_id: 0
|
||||
username: admin
|
||||
password: admin
|
||||
remote_ip: 192.168.0.10 # Example Ip Address. (The remote host's IP that will be used by ssh)
|
||||
|
||||
|
||||
``node-send-remote-command``
|
||||
""""""""""""""""""""""""""""
|
||||
|
||||
After remotely logging into another host, an agent can use the ``node-send-remote-command`` to execute commands across the network remotely.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
...
|
||||
...
|
||||
action: node-send-remote-command
|
||||
options:
|
||||
node_id: 0
|
||||
remote_ip: 192.168.0.10
|
||||
command:
|
||||
- "file_system"
|
||||
- "create"
|
||||
- "file"
|
||||
- "downloads"
|
||||
- "cat.png"
|
||||
- "False"
|
||||
|
||||
|
||||
|
||||
Simulation Layer Usage
|
||||
======================
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
The below code examples demonstrate how to create a terminal, a remote terminal, and how to send a basic application install command to a remote node.
|
||||
|
||||
@@ -65,7 +166,7 @@ Python
|
||||
operating_state=NodeOperatingState.ON,
|
||||
)
|
||||
|
||||
terminal: Terminal = client.software_manager.software.get("Terminal")
|
||||
terminal: Terminal = client.software_manager.software.get("terminal")
|
||||
|
||||
Creating Remote Terminal Connection
|
||||
"""""""""""""""""""""""""""""""""""
|
||||
@@ -86,7 +187,7 @@ Creating Remote Terminal Connection
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
@@ -112,12 +213,12 @@ Executing a basic application install command
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "RansomwareScript"])
|
||||
term_a_term_b_remote_connection.execute(["software_manager", "application", "install", "ransomware-script"])
|
||||
|
||||
|
||||
|
||||
@@ -140,7 +241,7 @@ Creating a folder on a remote node
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
@@ -167,7 +268,7 @@ Disconnect from Remote Node
|
||||
node_b.power_on()
|
||||
network.connect(node_a.network_interface[1], node_b.network_interface[1])
|
||||
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("Terminal")
|
||||
terminal_a: Terminal = node_a.software_manager.software.get("terminal")
|
||||
|
||||
|
||||
term_a_term_b_remote_connection: RemoteTerminalConnection = terminal_a.login(username="admin", password="Admin123!", ip_address="192.168.0.11")
|
||||
|
||||
@@ -56,7 +56,7 @@ Python
|
||||
|
||||
# Install WebServer on server
|
||||
server.software_manager.install(WebServer)
|
||||
web_server: WebServer = server.software_manager.software.get("WebServer")
|
||||
web_server: WebServer = server.software_manager.software.get("web-server")
|
||||
web_server.start()
|
||||
|
||||
Via Configuration
|
||||
|
||||
@@ -30,7 +30,7 @@ See :ref:`Node Start up and Shut down`
|
||||
|
||||
node.software_manager.install(WebServer)
|
||||
|
||||
web_server: WebServer = node.software_manager.software.get("WebServer")
|
||||
web_server: WebServer = node.software_manager.software.get("web-server")
|
||||
assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install
|
||||
|
||||
node.power_off()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
from typing import ClassVar, List, Literal, Optional, Union
|
||||
|
||||
from primaite.game.agent.actions.manager import AbstractAction
|
||||
from primaite.interface.request import RequestFormat
|
||||
@@ -153,8 +153,6 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, discriminator="node-nmap-po
|
||||
class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-network-service-recon"):
|
||||
"""Action which performs an nmap network service recon (ping scan followed by port scan)."""
|
||||
|
||||
config: "NodeNetworkServiceReconAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeNetworkServiceReconAction."""
|
||||
|
||||
@@ -179,3 +177,70 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, discriminator="node-
|
||||
"show": config.show,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class NodeAccountsAddUserAction(AbstractAction, discriminator="node-account-add-user"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-account-add-user"] = "node-account-add-user"
|
||||
node_name: str
|
||||
username: str
|
||||
password: str
|
||||
is_admin: bool
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"user-manager",
|
||||
"add_user",
|
||||
config.username,
|
||||
config.password,
|
||||
config.is_admin,
|
||||
]
|
||||
|
||||
|
||||
class NodeAccountsDisableUserAction(AbstractAction, discriminator="node-account-disable-user"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-account-disable-user"] = "node-account-disable-user"
|
||||
node_name: str
|
||||
username: str
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"user-manager",
|
||||
"disable_user",
|
||||
config.username,
|
||||
]
|
||||
|
||||
|
||||
class NodeSendLocalCommandAction(AbstractAction, discriminator="node-send-local-command"):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
type: Literal["node-send-local-command"] = "node-send-local-command"
|
||||
node_name: str
|
||||
username: str
|
||||
password: str
|
||||
command: RequestFormat
|
||||
|
||||
@staticmethod
|
||||
def form_request(config: ConfigSchema) -> RequestFormat:
|
||||
return [
|
||||
"network",
|
||||
"node",
|
||||
config.node_name,
|
||||
"service",
|
||||
"terminal",
|
||||
"send_local_command",
|
||||
config.username,
|
||||
config.password,
|
||||
{"command": config.command},
|
||||
]
|
||||
|
||||
@@ -34,8 +34,6 @@ class NodeSessionAbstractAction(AbstractAction, ABC):
|
||||
class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="node-session-remote-login"):
|
||||
"""Action which performs a remote session login."""
|
||||
|
||||
config: "NodeSessionsRemoteLoginAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLoginAction."""
|
||||
|
||||
@@ -53,7 +51,7 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
|
||||
config.node_name,
|
||||
"service",
|
||||
"terminal",
|
||||
"node-session-remote-login",
|
||||
"node_session_remote_login",
|
||||
config.username,
|
||||
config.password,
|
||||
config.remote_ip,
|
||||
@@ -63,8 +61,6 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, discriminator="no
|
||||
class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="node-session-remote-logoff"):
|
||||
"""Action which performs a remote session logout."""
|
||||
|
||||
config: "NodeSessionsRemoteLogoutAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeSessionsRemoteLogoutAction."""
|
||||
|
||||
@@ -78,14 +74,13 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, discriminator="n
|
||||
return ["network", "node", config.node_name, "service", "terminal", config.verb, config.remote_ip]
|
||||
|
||||
|
||||
class NodeAccountChangePasswordAction(NodeSessionAbstractAction, discriminator="node-account-change-password"):
|
||||
class NodeAccountChangePasswordAction(AbstractAction, discriminator="node-account-change-password"):
|
||||
"""Action which changes the password for a user."""
|
||||
|
||||
config: "NodeAccountChangePasswordAction.ConfigSchema"
|
||||
|
||||
class ConfigSchema(NodeSessionAbstractAction.ConfigSchema):
|
||||
class ConfigSchema(AbstractAction.ConfigSchema):
|
||||
"""Configuration schema for NodeAccountsChangePasswordAction."""
|
||||
|
||||
node_name: str
|
||||
username: str
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from prettytable import PrettyTable
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -42,6 +43,9 @@ class AgentHistoryItem(BaseModel):
|
||||
|
||||
reward_info: Dict[str, Any] = {}
|
||||
|
||||
observation: Optional[ObsType] = None
|
||||
"""The observation space data for this step."""
|
||||
|
||||
|
||||
class AbstractAgent(BaseModel, ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
@@ -67,6 +71,9 @@ class AbstractAgent(BaseModel, ABC):
|
||||
default_factory=lambda: ObservationManager.ConfigSchema()
|
||||
)
|
||||
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
|
||||
thresholds: Optional[Dict] = {}
|
||||
# TODO: this is only relevant to some observations, need to refactor the way thresholds are dealt with (#3085)
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
|
||||
|
||||
@@ -90,10 +97,42 @@ class AbstractAgent(BaseModel, ABC):
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Overwrite the default empty action, observation, and rewards with ones defined through the config."""
|
||||
self.action_manager = ActionManager(config=self.config.action_space)
|
||||
self.config.observation_space.options.thresholds = self.config.thresholds
|
||||
self.observation_manager = ObservationManager(config=self.config.observation_space)
|
||||
self.reward_function = RewardFunction(config=self.config.reward_function)
|
||||
return super().model_post_init(__context)
|
||||
|
||||
def add_agent_action(self, item: AgentHistoryItem, table: PrettyTable) -> PrettyTable:
|
||||
"""Update the given table with information from given AgentHistoryItem."""
|
||||
node, application = "unknown", "unknown"
|
||||
if (node_id := item.parameters.get("node_id")) is not None:
|
||||
node = self.action_manager.node_names[node_id]
|
||||
if (application_id := item.parameters.get("application_id")) is not None:
|
||||
application = self.action_manager.application_names[node_id][application_id]
|
||||
if (application_name := item.parameters.get("application_name")) is not None:
|
||||
application = application_name
|
||||
table.add_row([item.timestep, item.action, node, application, item.response.status])
|
||||
return table
|
||||
|
||||
def show_history(self, ignored_actions: Optional[list] = None):
|
||||
"""
|
||||
Print an agent action provided it's not the DONOTHING action.
|
||||
|
||||
:param ignored_actions: OPTIONAL: List of actions to be ignored when displaying the history.
|
||||
If not provided, defaults to ignore DONOTHING actions.
|
||||
"""
|
||||
if not ignored_actions:
|
||||
ignored_actions = ["DONOTHING"]
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Step", "Action", "Node", "Application", "Response"]
|
||||
print(f"Actions for '{self.agent_name}':")
|
||||
for item in self.history:
|
||||
if item.action in ignored_actions:
|
||||
pass
|
||||
else:
|
||||
table = self.add_agent_action(item=item, table=table)
|
||||
print(table)
|
||||
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
@@ -140,12 +179,23 @@ class AbstractAgent(BaseModel, ABC):
|
||||
return request
|
||||
|
||||
def process_action_response(
|
||||
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
|
||||
self,
|
||||
timestep: int,
|
||||
action: str,
|
||||
parameters: Dict[str, Any],
|
||||
request: RequestFormat,
|
||||
response: RequestResponse,
|
||||
observation: ObsType,
|
||||
) -> None:
|
||||
"""Process the response from the most recent action."""
|
||||
self.history.append(
|
||||
AgentHistoryItem(
|
||||
timestep=timestep, action=action, parameters=parameters, request=request, response=response
|
||||
timestep=timestep,
|
||||
action=action,
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=observation,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,13 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
file_system_requires_scan: Optional[bool] = None
|
||||
"""If True, the file must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool, file_system_requires_scan: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a file observation instance.
|
||||
|
||||
@@ -48,10 +54,36 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("file_access") is None:
|
||||
self.low_file_access_threshold = 0
|
||||
self.med_file_access_threshold = 5
|
||||
self.high_file_access_threshold = 10
|
||||
else:
|
||||
self._set_file_access_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("file_access")["low"],
|
||||
thresholds.get("file_access")["medium"],
|
||||
thresholds.get("file_access")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_file_access_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the file access threshold.
|
||||
|
||||
:param: thresholds: The file access threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="file_access",
|
||||
):
|
||||
self.low_file_access_threshold = thresholds[0]
|
||||
self.med_file_access_threshold = thresholds[1]
|
||||
self.high_file_access_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
@@ -60,11 +92,11 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
if num_access > self.high_file_access_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
elif num_access > self.med_file_access_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
elif num_access > self.low_file_access_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -122,6 +154,7 @@ class FileObservation(AbstractObservation, discriminator="file"):
|
||||
where=parent_where + ["files", config.file_name],
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,6 +182,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
num_files: int,
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
thresholds: Optional[Dict] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
@@ -177,6 +211,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
where=None,
|
||||
include_num_access=include_num_access,
|
||||
file_system_requires_scan=self.file_system_requires_scan,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
)
|
||||
while len(self.files) > num_files:
|
||||
@@ -253,6 +288,7 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
file_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
file_config.thresholds = config.thresholds
|
||||
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(
|
||||
@@ -261,4 +297,5 @@ class FolderObservation(AbstractObservation, discriminator="folder"):
|
||||
num_files=config.num_files,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -54,7 +54,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
"""
|
||||
If True, files and folders must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = None
|
||||
services_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, services must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
applications_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, applications must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
|
||||
def __init__(
|
||||
@@ -73,6 +81,8 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
monitored_traffic: Optional[Dict],
|
||||
include_num_access: bool,
|
||||
file_system_requires_scan: bool,
|
||||
services_requires_scan: bool,
|
||||
applications_requires_scan: bool,
|
||||
include_users: bool,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -108,6 +118,12 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
:param file_system_requires_scan: If True, the files and folders must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type file_system_requires_scan: bool
|
||||
:param services_requires_scan: If True, services must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type services_requires_scan: bool
|
||||
:param applications_requires_scan: If True, applications must be scanned to update the health state.
|
||||
If False, the true state is always shown.
|
||||
:type applications_requires_scan: bool
|
||||
:param include_users: If True, report user session information.
|
||||
:type include_users: bool
|
||||
"""
|
||||
@@ -121,7 +137,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
self.services.append(ServiceObservation(where=None, services_requires_scan=services_requires_scan))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
@@ -129,7 +145,9 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
self.applications.append(
|
||||
ApplicationObservation(where=None, applications_requires_scan=applications_requires_scan)
|
||||
)
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
|
||||
@@ -153,7 +171,13 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne, monitored_traffic=monitored_traffic))
|
||||
self.nics.append(
|
||||
NICObservation(
|
||||
where=None,
|
||||
include_nmne=include_nmne,
|
||||
monitored_traffic=monitored_traffic,
|
||||
)
|
||||
)
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
@@ -269,8 +293,15 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
folder_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
folder_config.thresholds = config.thresholds
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
nic_config.thresholds = config.thresholds
|
||||
for service_config in config.services:
|
||||
service_config.services_requires_scan = config.services_requires_scan
|
||||
for application_config in config.applications:
|
||||
application_config.applications_requires_scan = config.applications_requires_scan
|
||||
application_config.thresholds = config.thresholds
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
@@ -281,7 +312,10 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
count = 1
|
||||
while len(nics) < config.num_nics:
|
||||
nic_config = NICObservation.ConfigSchema(
|
||||
nic_num=count, include_nmne=config.include_nmne, monitored_traffic=config.monitored_traffic
|
||||
nic_num=count,
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
nics.append(NICObservation.from_config(config=nic_config, parent_where=where))
|
||||
count += 1
|
||||
@@ -301,5 +335,7 @@ class HostObservation(AbstractObservation, discriminator="host"):
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
include_num_access=config.include_num_access,
|
||||
file_system_requires_scan=config.file_system_requires_scan,
|
||||
services_requires_scan=config.services_requires_scan,
|
||||
applications_requires_scan=config.applications_requires_scan,
|
||||
include_users=config.include_users,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import NMNEConfig
|
||||
from primaite.utils.validation.ip_protocol import IPProtocol
|
||||
from primaite.utils.validation.port import Port
|
||||
|
||||
@@ -15,6 +16,9 @@ from primaite.utils.validation.port import Port
|
||||
class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
|
||||
capture_nmne: ClassVar[bool] = NMNEConfig().capture_nmne
|
||||
"A Boolean specifying whether malicious network events should be captured."
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
|
||||
@@ -25,7 +29,13 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
monitored_traffic: Optional[Dict[IPProtocol, List[Port]]] = None
|
||||
"""A dict containing which traffic types are to be included in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool, monitored_traffic: Optional[Dict] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
monitored_traffic: Optional[Dict] = None,
|
||||
thresholds: Dict = {},
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
@@ -45,10 +55,18 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
if thresholds.get("nmne") is None:
|
||||
self.low_nmne_threshold = 0
|
||||
self.med_nmne_threshold = 5
|
||||
self.high_nmne_threshold = 10
|
||||
else:
|
||||
self._set_nmne_threshold(
|
||||
thresholds=[
|
||||
thresholds.get("nmne")["low"],
|
||||
thresholds.get("nmne")["medium"],
|
||||
thresholds.get("nmne")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
self.monitored_traffic = monitored_traffic
|
||||
if self.monitored_traffic:
|
||||
@@ -105,6 +123,20 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
bandwidth_utilisation = traffic_value / nic_max_bandwidth
|
||||
return int(bandwidth_utilisation * 9) + 1
|
||||
|
||||
def _set_nmne_threshold(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the NMNE threshold.
|
||||
|
||||
:param: thresholds: The NMNE threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=thresholds,
|
||||
threshold_identifier="nmne",
|
||||
):
|
||||
self.low_nmne_threshold = thresholds[0]
|
||||
self.med_nmne_threshold = thresholds[1]
|
||||
self.high_nmne_threshold = thresholds[2]
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
@@ -116,7 +148,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
"""
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
if nic_state is NOT_PRESENT_IN_STATE or self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
@@ -164,7 +196,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
for port in self.monitored_traffic[protocol]:
|
||||
obs["TRAFFIC"][protocol][port] = {"inbound": 0, "outbound": 0}
|
||||
|
||||
if self.include_nmne:
|
||||
if self.capture_nmne and self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
@@ -224,6 +256,7 @@ class NICObservation(AbstractObservation, discriminator="network-interface"):
|
||||
where=parent_where + ["NICs", config.nic_num],
|
||||
include_nmne=config.include_nmne,
|
||||
monitored_traffic=config.monitored_traffic,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,13 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
file_system_requires_scan: bool = True
|
||||
"""If True, the folder must be scanned to update the health state. Tf False, the true state is always shown."""
|
||||
"""If True, the folder must be scanned to update the health state. If False, the true state is always shown."""
|
||||
services_requires_scan: bool = True
|
||||
"""If True, the services must be scanned to update the health state.
|
||||
If False, the true state is always shown."""
|
||||
applications_requires_scan: bool = True
|
||||
"""If True, the applications must be scanned to update the health state.
|
||||
If False, the true state is always shown."""
|
||||
include_users: Optional[bool] = True
|
||||
"""If True, report user session information."""
|
||||
num_ports: Optional[int] = None
|
||||
@@ -196,8 +202,14 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
host_config.include_num_access = config.include_num_access
|
||||
if host_config.file_system_requires_scan is None:
|
||||
host_config.file_system_requires_scan = config.file_system_requires_scan
|
||||
if host_config.services_requires_scan is None:
|
||||
host_config.services_requires_scan = config.services_requires_scan
|
||||
if host_config.applications_requires_scan is None:
|
||||
host_config.applications_requires_scan = config.applications_requires_scan
|
||||
if host_config.include_users is None:
|
||||
host_config.include_users = config.include_users
|
||||
if not host_config.thresholds:
|
||||
host_config.thresholds = config.thresholds
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
@@ -214,6 +226,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
router_config.num_rules = config.num_rules
|
||||
if router_config.include_users is None:
|
||||
router_config.include_users = config.include_users
|
||||
if not router_config.thresholds:
|
||||
router_config.thresholds = config.thresholds
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
@@ -228,6 +242,8 @@ class NodesObservation(AbstractObservation, discriminator="nodes"):
|
||||
firewall_config.num_rules = config.num_rules
|
||||
if firewall_config.include_users is None:
|
||||
firewall_config.include_users = config.include_users
|
||||
if not firewall_config.thresholds:
|
||||
firewall_config.thresholds = config.thresholds
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
|
||||
@@ -114,7 +114,9 @@ class NestedObservation(AbstractObservation, discriminator="custom"):
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options))
|
||||
obs_instance = obs_class.from_config(
|
||||
config=obs_class.ConfigSchema(**component.options, thresholds=config.thresholds)
|
||||
)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
@@ -242,8 +244,5 @@ class ObservationManager(BaseModel):
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]))
|
||||
obs_manager = cls(observation)
|
||||
obs_manager = cls(config=config)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,6 +19,9 @@ class AbstractObservation(ABC):
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
thresholds: Optional[Dict] = {}
|
||||
"""A dict containing the observation thresholds."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
@@ -69,3 +72,34 @@ class AbstractObservation(ABC):
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
def _validate_thresholds(self, thresholds: List[int] = None, threshold_identifier: Optional[str] = "") -> bool:
|
||||
"""
|
||||
Method that checks if the thresholds are non overlapping and in the correct (ascending) order.
|
||||
|
||||
Pass in the thresholds from low to high e.g.
|
||||
thresholds=[low_threshold, med_threshold, ..._threshold, high_threshold]
|
||||
|
||||
Throws an error if the threshold is not valid
|
||||
|
||||
:param: thresholds: List of thresholds in ascending order.
|
||||
:type: List[int]
|
||||
:param: threshold_identifier: The name of the threshold option.
|
||||
:type: Optional[str]
|
||||
|
||||
:returns: bool
|
||||
"""
|
||||
if thresholds is None or len(thresholds) < 2:
|
||||
raise Exception(f"{threshold_identifier} thresholds are invalid {thresholds}")
|
||||
for idx in range(1, len(thresholds)):
|
||||
if not isinstance(thresholds[idx], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
if not isinstance(thresholds[idx - 1], int):
|
||||
raise Exception(f"{threshold_identifier} threshold ({thresholds[idx]}) is not a valid int.")
|
||||
|
||||
if thresholds[idx] <= thresholds[idx - 1]:
|
||||
raise Exception(
|
||||
f"{threshold_identifier} threshold ({thresholds[idx - 1]}) "
|
||||
f"is greater than or equal to ({thresholds[idx]}.)"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -19,7 +19,10 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
services_requires_scan: Optional[bool] = None
|
||||
"""If True, services must be scanned to update the health state. If False, true state is always shown."""
|
||||
|
||||
def __init__(self, where: WhereType, services_requires_scan: bool) -> None:
|
||||
"""
|
||||
Initialise a service observation instance.
|
||||
|
||||
@@ -28,6 +31,7 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.services_requires_scan = services_requires_scan
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
@@ -44,7 +48,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": service_state["operating_state"],
|
||||
"health_status": service_state["health_state_visible"],
|
||||
"health_status": service_state["health_state_visible"]
|
||||
if self.services_requires_scan
|
||||
else service_state["health_state_actual"],
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -70,7 +76,9 @@ class ServiceObservation(AbstractObservation, discriminator="service"):
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
return cls(
|
||||
where=parent_where + ["services", config.service_name], services_requires_scan=config.services_requires_scan
|
||||
)
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
@@ -82,7 +90,12 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
applications_requires_scan: Optional[bool] = None
|
||||
"""
|
||||
If True, applications must be scanned to update the health state. If False, true state is always shown.
|
||||
"""
|
||||
|
||||
def __init__(self, where: WhereType, applications_requires_scan: bool, thresholds: Optional[Dict] = {}) -> None:
|
||||
"""
|
||||
Initialise an application observation instance.
|
||||
|
||||
@@ -92,25 +105,52 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.applications_requires_scan = applications_requires_scan
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
if thresholds.get("app_executions") is None:
|
||||
self.low_app_execution_threshold = 0
|
||||
self.med_app_execution_threshold = 5
|
||||
self.high_app_execution_threshold = 10
|
||||
else:
|
||||
self._set_application_execution_thresholds(
|
||||
thresholds=[
|
||||
thresholds.get("app_executions")["low"],
|
||||
thresholds.get("app_executions")["medium"],
|
||||
thresholds.get("app_executions")["high"],
|
||||
]
|
||||
)
|
||||
|
||||
def _set_application_execution_thresholds(self, thresholds: List[int]):
|
||||
"""
|
||||
Method that validates and then sets the application execution threshold.
|
||||
|
||||
:param: thresholds: The application execution threshold to validate and set.
|
||||
"""
|
||||
if self._validate_thresholds(
|
||||
thresholds=[
|
||||
thresholds[0],
|
||||
thresholds[1],
|
||||
thresholds[2],
|
||||
],
|
||||
threshold_identifier="app_executions",
|
||||
):
|
||||
self.low_app_execution_threshold = thresholds[0]
|
||||
self.med_app_execution_threshold = thresholds[1]
|
||||
self.high_app_execution_threshold = thresholds[2]
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
Represent number of application executions as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:param num_access: Number of application executions.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
if num_executions > self.high_app_execution_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
elif num_executions > self.med_app_execution_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
elif num_executions > self.low_app_execution_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
@@ -128,7 +168,9 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"health_status": application_state["health_state_visible"]
|
||||
if self.applications_requires_scan
|
||||
else application_state["health_state_actual"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@@ -161,4 +203,8 @@ class ApplicationObservation(AbstractObservation, discriminator="application"):
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
return cls(
|
||||
where=parent_where + ["applications", config.application_name],
|
||||
applications_requires_scan=config.applications_requires_scan,
|
||||
thresholds=config.thresholds,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import DEFAULT_BANDWIDTH, getLogger
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.observations import NICObservation
|
||||
from primaite.game.agent.rewards import SharedReward
|
||||
from primaite.game.science import graph_has_cycle, topological_sort
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
@@ -44,15 +45,15 @@ from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
SERVICE_TYPES_MAPPING = {
|
||||
"DNSClient": DNSClient,
|
||||
"DNSServer": DNSServer,
|
||||
"DatabaseService": DatabaseService,
|
||||
"WebServer": WebServer,
|
||||
"FTPClient": FTPClient,
|
||||
"FTPServer": FTPServer,
|
||||
"NTPClient": NTPClient,
|
||||
"NTPServer": NTPServer,
|
||||
"Terminal": Terminal,
|
||||
"dns-client": DNSClient,
|
||||
"dns-server": DNSServer,
|
||||
"database-service": DatabaseService,
|
||||
"web-server": WebServer,
|
||||
"ftp-client": FTPClient,
|
||||
"ftp-server": FTPServer,
|
||||
"ntp-client": NTPClient,
|
||||
"ntp-server": NTPServer,
|
||||
"terminal": Terminal,
|
||||
}
|
||||
"""List of available services that can be installed on nodes in the PrimAITE Simulation."""
|
||||
|
||||
@@ -68,6 +69,8 @@ class PrimaiteGameOptions(BaseModel):
|
||||
|
||||
seed: int = None
|
||||
"""Random number seed for RNGs."""
|
||||
generate_seed_value: bool = False
|
||||
"""Internally generated seed value."""
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[Port]
|
||||
@@ -175,6 +178,7 @@ class PrimaiteGame:
|
||||
parameters=parameters,
|
||||
request=request,
|
||||
response=response,
|
||||
observation=obs,
|
||||
)
|
||||
|
||||
def pre_timestep(self) -> None:
|
||||
@@ -263,6 +267,7 @@ class PrimaiteGame:
|
||||
node_sets_cfg = network_config.get("node_sets", [])
|
||||
# Set the NMNE capture config
|
||||
NetworkInterface.nmne_config = NMNEConfig(**network_config.get("nmne_config", {}))
|
||||
NICObservation.capture_nmne = NMNEConfig(**network_config.get("nmne_config", {})).capture_nmne
|
||||
|
||||
for node_cfg in nodes_cfg:
|
||||
n_type = node_cfg["type"]
|
||||
@@ -293,6 +298,7 @@ class PrimaiteGame:
|
||||
|
||||
if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
|
||||
user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
|
||||
|
||||
for user_cfg in node_cfg["users"]:
|
||||
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)
|
||||
|
||||
@@ -407,6 +413,7 @@ class PrimaiteGame:
|
||||
agents_cfg = cfg.get("agents", [])
|
||||
|
||||
for agent_cfg in agents_cfg:
|
||||
agent_cfg = {**agent_cfg, "thresholds": game.options.thresholds}
|
||||
new_agent = AbstractAgent.from_config(agent_cfg)
|
||||
game.agents[agent_cfg["ref"]] = new_agent
|
||||
if isinstance(new_agent, ProxyAgent):
|
||||
|
||||
@@ -50,40 +50,22 @@
|
||||
"custom_c2_agent = \"\"\"\n",
|
||||
" - ref: CustomC2Agent\n",
|
||||
" team: RED\n",
|
||||
" type: ProxyAgent\n",
|
||||
" type: proxy-a.gent\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Server\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 2\n",
|
||||
" max_nics_per_node: 8\n",
|
||||
" max_acl_rules: 10\n",
|
||||
" ip_list:\n",
|
||||
" - 192.168.1.21\n",
|
||||
" - 192.168.1.14\n",
|
||||
" wildcard_list:\n",
|
||||
" - 0.0.0.1\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: node_application_install\n",
|
||||
" action: node-application-install\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" node_name: web_server\n",
|
||||
" application_name: c2-beacon\n",
|
||||
" 2:\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" action: configure-c2-beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" node_name: web_server\n",
|
||||
" config:\n",
|
||||
" c2_server_ip_address: 192.168.10.21\n",
|
||||
" keep_alive_frequency:\n",
|
||||
@@ -92,10 +74,10 @@
|
||||
" 3:\n",
|
||||
" action: node_application_execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0\n",
|
||||
" node_name: web_server\n",
|
||||
" application_name: c2-beacon\n",
|
||||
" 4:\n",
|
||||
" action: c2_server_terminal_command\n",
|
||||
" action: c2-server-terminal-command\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" ip_address:\n",
|
||||
@@ -111,14 +93,14 @@
|
||||
" 5:\n",
|
||||
" action: c2-server-ransomware-configure\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" node_name: client_1\n",
|
||||
" config:\n",
|
||||
" server_ip_address: 192.168.1.14\n",
|
||||
" payload: ENCRYPT\n",
|
||||
" 6:\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" action: c2-server-data-exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" node_name: client_1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
" target_folder_name: \"database\"\n",
|
||||
" exfiltration_folder_name: \"spoils\"\n",
|
||||
@@ -128,31 +110,27 @@
|
||||
" password: admin\n",
|
||||
"\n",
|
||||
" 7:\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" action: c2-server-ransomware-launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" node_name: client_1\n",
|
||||
" 8:\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" action: configure-c2-beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" node_name: web_server\n",
|
||||
" config:\n",
|
||||
" c2_server_ip_address: 192.168.10.21\n",
|
||||
" keep_alive_frequency: 10\n",
|
||||
" masquerade_protocol: TCP\n",
|
||||
" masquerade_port: DNS\n",
|
||||
" 9:\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" action: configure-c2-beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" node_name: web_server\n",
|
||||
" config:\n",
|
||||
" c2_server_ip_address: 192.168.10.22\n",
|
||||
" keep_alive_frequency:\n",
|
||||
" masquerade_protocol:\n",
|
||||
" masquerade_port:\n",
|
||||
"\n",
|
||||
" reward_function:\n",
|
||||
" reward_components:\n",
|
||||
" - type: DUMMY\n",
|
||||
"\"\"\"\n",
|
||||
"c2_agent_yaml = yaml.safe_load(custom_c2_agent)"
|
||||
]
|
||||
@@ -225,7 +203,7 @@
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - application_name: c2-beacon\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
@@ -233,7 +211,7 @@
|
||||
" action: node_application_install \n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" application_name: c2-beacon\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
@@ -268,7 +246,7 @@
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 2:\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" action: configure-c2-beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Node Index\n",
|
||||
" config: # Further information about these config options can be found at the bottom of this notebook.\n",
|
||||
@@ -286,7 +264,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(2)\n",
|
||||
"c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
|
||||
"c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
|
||||
"web_server.software_manager.show()\n",
|
||||
"c2_beacon.show()"
|
||||
]
|
||||
@@ -307,13 +285,13 @@
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - application_name: c2-beacon\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" ...\n",
|
||||
" 3:\n",
|
||||
" action: node_application_execute\n",
|
||||
" action: node-application-execute\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" application_id: 0\n",
|
||||
@@ -374,11 +352,11 @@
|
||||
" ...\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Server\n",
|
||||
" - application_name: c2-server\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 4:\n",
|
||||
" action: C2_SERVER_TERMINAL_COMMAND\n",
|
||||
" action: c2-server-terminal-command\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" ip_address:\n",
|
||||
@@ -431,7 +409,7 @@
|
||||
" ...\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Server\n",
|
||||
" - application_name: c2-server\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 5:\n",
|
||||
@@ -459,7 +437,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ransomware_script: RansomwareScript = web_server.software_manager.software[\"RansomwareScript\"]\n",
|
||||
"ransomware_script: RansomwareScript = web_server.software_manager.software[\"ransomware-script\"]\n",
|
||||
"web_server.software_manager.show()\n",
|
||||
"ransomware_script.show()"
|
||||
]
|
||||
@@ -483,11 +461,11 @@
|
||||
" ...\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Server\n",
|
||||
" - application_name: c2-server\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 6:\n",
|
||||
" action: c2_server_data_exfiltrate\n",
|
||||
" action: c2-server-data-exfiltrate\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
" target_file_name: \"database.db\"\n",
|
||||
@@ -549,11 +527,11 @@
|
||||
" ...\n",
|
||||
" - node_name: client_1\n",
|
||||
" applications: \n",
|
||||
" - application_name: C2Server\n",
|
||||
" - application_name: c2-server\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 7:\n",
|
||||
" action: c2_server_ransomware_launch\n",
|
||||
" action: c2-server-ransomware-launch\n",
|
||||
" options:\n",
|
||||
" node_id: 1\n",
|
||||
"```\n"
|
||||
@@ -598,20 +576,20 @@
|
||||
"custom_blue_agent_yaml = \"\"\"\n",
|
||||
" - ref: defender\n",
|
||||
" team: BLUE\n",
|
||||
" type: ProxyAgent\n",
|
||||
" type: proxy-agent\n",
|
||||
"\n",
|
||||
" observation_space:\n",
|
||||
" type: CUSTOM\n",
|
||||
" type: custom\n",
|
||||
" options:\n",
|
||||
" components:\n",
|
||||
" - type: NODES\n",
|
||||
" - type: nodes\n",
|
||||
" label: NODES\n",
|
||||
" options:\n",
|
||||
" hosts:\n",
|
||||
" - hostname: web_server\n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Beacon\n",
|
||||
" - application_name: RansomwareScript\n",
|
||||
" - application_name: c2-beacon\n",
|
||||
" - application_name: ransomware-script\n",
|
||||
" folders:\n",
|
||||
" - folder_name: exfiltration_folder\n",
|
||||
" files:\n",
|
||||
@@ -661,7 +639,7 @@
|
||||
" - UDP\n",
|
||||
" num_rules: 10\n",
|
||||
"\n",
|
||||
" - type: LINKS\n",
|
||||
" - type: links\n",
|
||||
" label: LINKS\n",
|
||||
" options:\n",
|
||||
" link_references:\n",
|
||||
@@ -675,7 +653,7 @@
|
||||
" - switch_2:eth-1<->client_1:eth-1\n",
|
||||
" - switch_2:eth-2<->client_2:eth-1\n",
|
||||
" - switch_2:eth-7<->security_suite:eth-2\n",
|
||||
" - type: \"NONE\"\n",
|
||||
" - type: \"none\"\n",
|
||||
" label: ICS\n",
|
||||
" options: {}\n",
|
||||
"\n",
|
||||
@@ -685,16 +663,16 @@
|
||||
" action: do_nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: node_application_remove\n",
|
||||
" action: node-application-remove\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" node_name: web-server\n",
|
||||
" application_name: C2Beacon\n",
|
||||
" 2:\n",
|
||||
" action: node_shutdown\n",
|
||||
" action: node-shutdown\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" node_name: web-server\n",
|
||||
" 3:\n",
|
||||
" action: router_acl_add_rule\n",
|
||||
" action: router-acl-add-rule\n",
|
||||
" options:\n",
|
||||
" target_router: router_1\n",
|
||||
" position: 1\n",
|
||||
@@ -707,36 +685,6 @@
|
||||
" source_wildcard_id: 0\n",
|
||||
" dest_wildcard_id: 0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: web_server\n",
|
||||
" applications:\n",
|
||||
" - application_name: C2Beacon\n",
|
||||
"\n",
|
||||
" - node_name: database_server\n",
|
||||
" folders:\n",
|
||||
" - folder_name: database\n",
|
||||
" files:\n",
|
||||
" - file_name: database.db\n",
|
||||
" services:\n",
|
||||
" - service_name: DatabaseService\n",
|
||||
" - node_name: router_1\n",
|
||||
"\n",
|
||||
" max_folders_per_node: 2\n",
|
||||
" max_files_per_folder: 2\n",
|
||||
" max_services_per_node: 2\n",
|
||||
" max_nics_per_node: 8\n",
|
||||
" max_acl_rules: 10\n",
|
||||
" ip_list:\n",
|
||||
" - 192.168.10.21\n",
|
||||
" - 192.168.1.12\n",
|
||||
" wildcard_list:\n",
|
||||
" - 0.0.0.1\n",
|
||||
" reward_function:\n",
|
||||
" reward_components:\n",
|
||||
" - type: DUMMY\n",
|
||||
"\n",
|
||||
" agent_settings:\n",
|
||||
" flatten_obs: False\n",
|
||||
"\"\"\"\n",
|
||||
@@ -875,7 +823,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Installing RansomwareScript via C2 Terminal Commands\n",
|
||||
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n",
|
||||
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n",
|
||||
" \"username\": \"admin\",\n",
|
||||
" \"password\": \"admin\"}\n",
|
||||
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)\n"
|
||||
@@ -1034,11 +982,11 @@
|
||||
" web_server: Server = given_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n",
|
||||
"\n",
|
||||
" client_1.software_manager.install(C2Server)\n",
|
||||
" c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
|
||||
" c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
|
||||
" c2_server.run()\n",
|
||||
"\n",
|
||||
" web_server.software_manager.install(C2Beacon)\n",
|
||||
" c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
|
||||
" c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
|
||||
" c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\")\n",
|
||||
" c2_beacon.establish()\n",
|
||||
"\n",
|
||||
@@ -1132,11 +1080,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Attempting to install the C2 RansomwareScript\n",
|
||||
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"]],\n",
|
||||
"ransomware_install_command = {\"commands\":[[\"software_manager\", \"application\", \"install\", \"ransomware-script\"]],\n",
|
||||
" \"username\": \"admin\",\n",
|
||||
" \"password\": \"admin\"}\n",
|
||||
"\n",
|
||||
"c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
|
||||
"c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
|
||||
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)"
|
||||
]
|
||||
},
|
||||
@@ -1220,11 +1168,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Attempting to install the C2 RansomwareScript\n",
|
||||
"ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"RansomwareScript\"],\n",
|
||||
"ransomware_install_command = {\"commands\":[\"software_manager\", \"application\", \"install\", \"ransomware-script\"],\n",
|
||||
" \"username\": \"admin\",\n",
|
||||
" \"password\": \"admin\"}\n",
|
||||
"\n",
|
||||
"c2_server: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
|
||||
"c2_server: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
|
||||
"c2_server.send_command(C2Command.TERMINAL, command_options=ransomware_install_command)"
|
||||
]
|
||||
},
|
||||
@@ -1345,7 +1293,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database_server\")\n",
|
||||
"database_server: Server = blue_env.game.simulation.network.get_node_by_hostname(\"database-server\")\n",
|
||||
"database_server.software_manager.file_system.show(full=True)"
|
||||
]
|
||||
},
|
||||
@@ -1391,7 +1339,7 @@
|
||||
"\n",
|
||||
"``` YAML\n",
|
||||
"...\n",
|
||||
" action: configure_c2_beacon\n",
|
||||
" action: configure-c2-beacon\n",
|
||||
" options:\n",
|
||||
" node_id: 0\n",
|
||||
" config:\n",
|
||||
@@ -1446,16 +1394,16 @@
|
||||
"source": [
|
||||
"web_server: Server = c2_config_env.game.simulation.network.get_node_by_hostname(\"web_server\")\n",
|
||||
"web_server.software_manager.install(C2Beacon)\n",
|
||||
"c2_beacon: C2Beacon = web_server.software_manager.software[\"C2Beacon\"]\n",
|
||||
"c2_beacon: C2Beacon = web_server.software_manager.software[\"c2-beacon\"]\n",
|
||||
"\n",
|
||||
"client_1: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
|
||||
"client_1.software_manager.install(C2Server)\n",
|
||||
"c2_server_1: C2Server = client_1.software_manager.software[\"C2Server\"]\n",
|
||||
"c2_server_1: C2Server = client_1.software_manager.software[\"c2-server\"]\n",
|
||||
"c2_server_1.run()\n",
|
||||
"\n",
|
||||
"client_2: Computer = c2_config_env.game.simulation.network.get_node_by_hostname(\"client_2\")\n",
|
||||
"client_2.software_manager.install(C2Server)\n",
|
||||
"c2_server_2: C2Server = client_2.software_manager.software[\"C2Server\"]\n",
|
||||
"c2_server_2: C2Server = client_2.software_manager.software[\"c2-server\"]\n",
|
||||
"c2_server_2.run()"
|
||||
]
|
||||
},
|
||||
@@ -1759,6 +1707,16 @@
|
||||
"\n",
|
||||
"display_obs_diffs(tcp_c2_obs, udp_c2_obs, blue_config_env.game.step_counter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"env.game.agents[\"CustomC2Agent\"].show_history()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
"source": [
|
||||
"def make_cfg_have_flat_obs(cfg):\n",
|
||||
" for agent in cfg['agents']:\n",
|
||||
" if agent['type'] == \"ProxyAgent\":\n",
|
||||
" if agent['type'] == \"proxy-agent\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = False"
|
||||
]
|
||||
},
|
||||
@@ -76,9 +76,9 @@
|
||||
" # parse the info dict form step output and write out what the red agent is doing\n",
|
||||
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
|
||||
" red_action = red_info.action\n",
|
||||
" if red_action == 'do_nothing':\n",
|
||||
" if red_action == 'do-nothing':\n",
|
||||
" red_str = 'DO NOTHING'\n",
|
||||
" elif red_action == 'node_application_execute':\n",
|
||||
" elif red_action == 'node-application-execute':\n",
|
||||
" client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n",
|
||||
" red_str = f\"ATTACK from {client}\"\n",
|
||||
" return red_str"
|
||||
@@ -147,36 +147,14 @@
|
||||
"```yaml\n",
|
||||
" - ref: data_manipulation_attacker # name of agent\n",
|
||||
" team: RED # not used, just for human reference\n",
|
||||
" type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n",
|
||||
" type: red-database-corrupting-agent # type of agent - this lets primaite know which agent class to use\n",
|
||||
"\n",
|
||||
" # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n",
|
||||
" observation_space:\n",
|
||||
" type: UC2RedObservation\n",
|
||||
" type: uc2-red-observation # TODO: what\n",
|
||||
" options:\n",
|
||||
" nodes: {}\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" \n",
|
||||
" # The agent has access to the DataManipulationBoth on clients 1 and 2.\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1 # The network should have a node called client_1\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n",
|
||||
" - node_name: client_2 # The network should have a node called client_2\n",
|
||||
" applications:\n",
|
||||
" - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n",
|
||||
"\n",
|
||||
" # not important\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 1\n",
|
||||
"\n",
|
||||
" # red agent does not need a reward function\n",
|
||||
" reward_function:\n",
|
||||
" reward_components:\n",
|
||||
" - type: DUMMY\n",
|
||||
"\n",
|
||||
" # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n",
|
||||
" agent_settings:\n",
|
||||
" start_settings:\n",
|
||||
@@ -211,15 +189,13 @@
|
||||
" \n",
|
||||
" # \n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" - type: data-manipulation-bot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 0.8 # Probability that port scan is successful\n",
|
||||
" data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n",
|
||||
" payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n",
|
||||
" server_ip: 192.168.1.14 # IP address of server hosting the database\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n",
|
||||
" - type: database-client # Database client must be installed in order for DataManipulationBot to function\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14 # IP address of server hosting the database\n",
|
||||
"```"
|
||||
@@ -354,19 +330,16 @@
|
||||
"# Make attack always succeed.\n",
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" - type: data-manipulation-bot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 1.0\n",
|
||||
" data_manipulation_p_of_success: 1.0\n",
|
||||
" payload: \"DELETE\"\n",
|
||||
" server_ip: 192.168.1.14\n",
|
||||
" - ref: client_1_web_browser\n",
|
||||
" type: WebBrowser\n",
|
||||
" - type: web-browser\n",
|
||||
" options:\n",
|
||||
" target_url: http://arcd.com/users/\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient\n",
|
||||
" - type: database-client\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14\n",
|
||||
"\"\"\")\n",
|
||||
@@ -399,19 +372,16 @@
|
||||
"# Make attack always fail.\n",
|
||||
"change = yaml.safe_load(\"\"\"\n",
|
||||
" applications:\n",
|
||||
" - ref: data_manipulation_bot\n",
|
||||
" type: DataManipulationBot\n",
|
||||
" - type: data-manipulation-bot\n",
|
||||
" options:\n",
|
||||
" port_scan_p_of_success: 0.0\n",
|
||||
" data_manipulation_p_of_success: 0.0\n",
|
||||
" payload: \"DELETE\"\n",
|
||||
" server_ip: 192.168.1.14\n",
|
||||
" - ref: client_1_web_browser\n",
|
||||
" type: WebBrowser\n",
|
||||
" - type: web-browser\n",
|
||||
" options:\n",
|
||||
" target_url: http://arcd.com/users/\n",
|
||||
" - ref: client_1_database_client\n",
|
||||
" type: DatabaseClient\n",
|
||||
" - type: database-client\n",
|
||||
" options:\n",
|
||||
" db_server_ip: 192.168.1.14\n",
|
||||
"\"\"\")\n",
|
||||
|
||||
@@ -684,6 +684,15 @@
|
||||
" print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.game.agents[\"data_manipulation_attacker\"].show_history()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -717,7 +726,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -153,6 +153,49 @@
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"enabled\"] = was_enabled\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"][\"output_sys_logs\"] = was_syslogs_enabled"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Viewing Agent history\n",
|
||||
"\n",
|
||||
"It's possible to view the actions carried out by an agent for a given training session using the `show_history()` method. By default, this will be all actions apart from DONOTHING actions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
"\n",
|
||||
"# Run the training session to generate some resultant data.\n",
|
||||
"for i in range(100):\n",
|
||||
" env.step(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `.show_history()` should show us when the Data Manipulation used the `NODE_APPLICATION_EXECUTE` action."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"attacker = env.game.agents[\"data_manipulation_attacker\"]\n",
|
||||
"\n",
|
||||
"attacker.show_history()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -171,7 +214,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.8"
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
479
src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb
Normal file
479
src/primaite/notebooks/How-To-Use-Primaite-Dev-Mode.ipynb
Normal file
@@ -0,0 +1,479 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PrimAITE Developer mode\n",
|
||||
"\n",
|
||||
"PrimAITE has built in developer tools.\n",
|
||||
"\n",
|
||||
"The dev-mode is designed to help make the development of PrimAITE easier.\n",
|
||||
"\n",
|
||||
"`NOTE: For the purposes of the notebook, the commands are preceeded by \"!\". When running the commands, run it without the \"!\".`\n",
|
||||
"\n",
|
||||
"To display the available dev-mode options, run the command below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode --help"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save the current PRIMAITE_CONFIG to restore after the notebook runs\n",
|
||||
"\n",
|
||||
"from primaite import PRIMAITE_CONFIG\n",
|
||||
"\n",
|
||||
"temp_config = PRIMAITE_CONFIG.copy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Dev mode options"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### enable\n",
|
||||
"\n",
|
||||
"Enables the dev mode for PrimAITE.\n",
|
||||
"\n",
|
||||
"This will enable the developer mode for PrimAITE.\n",
|
||||
"\n",
|
||||
"By default, when developer mode is enabled, session logs will be generated in the PRIMAITE_ROOT/sessions folder unless configured to be generated in another location."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode enable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### disable\n",
|
||||
"\n",
|
||||
"Disables the dev mode for PrimAITE.\n",
|
||||
"\n",
|
||||
"This will disable the developer mode for PrimAITE."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode disable"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### show\n",
|
||||
"\n",
|
||||
"Shows if PrimAITE is running in dev mode or production mode.\n",
|
||||
"\n",
|
||||
"The command will also show the developer mode configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode show"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### config\n",
|
||||
"\n",
|
||||
"Configure the PrimAITE developer mode"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --help"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### path\n",
|
||||
"\n",
|
||||
"Set the path where generated session files will be output.\n",
|
||||
"\n",
|
||||
"By default, this value will be in PRIMAITE_ROOT/sessions.\n",
|
||||
"\n",
|
||||
"To reset the path to default, run:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config path -root\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config path --default"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --sys-log-level or -slevel\n",
|
||||
"\n",
|
||||
"Set the system log level.\n",
|
||||
"\n",
|
||||
"This will override the system log level in configurations and will make PrimAITE include the set log level and above.\n",
|
||||
"\n",
|
||||
"Available options are:\n",
|
||||
"- `DEBUG`\n",
|
||||
"- `INFO`\n",
|
||||
"- `WARNING`\n",
|
||||
"- `ERROR`\n",
|
||||
"- `CRITICAL`\n",
|
||||
"\n",
|
||||
"Default value is `DEBUG`\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --sys-log-level DEBUG\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -slevel DEBUG"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --agent-log-level or -alevel\n",
|
||||
"\n",
|
||||
"Set the agent log level.\n",
|
||||
"\n",
|
||||
"This will override the agent log level in configurations and will make PrimAITE include the set log level and above.\n",
|
||||
"\n",
|
||||
"Available options are:\n",
|
||||
"- `DEBUG`\n",
|
||||
"- `INFO`\n",
|
||||
"- `WARNING`\n",
|
||||
"- `ERROR`\n",
|
||||
"- `CRITICAL`\n",
|
||||
"\n",
|
||||
"Default value is `DEBUG`\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --agent-log-level DEBUG\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -alevel DEBUG"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --output-sys-logs or -sys\n",
|
||||
"\n",
|
||||
"If enabled, developer mode will output system logs.\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --output-sys-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -sys"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To disable outputting sys logs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --no-sys-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -nsys"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --output-agent-logs or -agent\n",
|
||||
"\n",
|
||||
"If enabled, developer mode will output agent action logs.\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --output-agent-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To disable outputting agent action logs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --no-agent-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -nagent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --output-pcap-logs or -pcap\n",
|
||||
"\n",
|
||||
"If enabled, developer mode will output PCAP logs.\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --output-pcap-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -pcap"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To disable outputting PCAP logs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --no-pcap-logs\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -npcap"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### --output-to-terminal or -t\n",
|
||||
"\n",
|
||||
"If enabled, developer mode will output logs to the terminal.\n",
|
||||
"\n",
|
||||
"Example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --output-to-terminal\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -t"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To disable terminal outputs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config --no-terminal\n",
|
||||
"\n",
|
||||
"# or\n",
|
||||
"\n",
|
||||
"!primaite dev-mode config -nt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Combining commands\n",
|
||||
"\n",
|
||||
"It is possible to combine commands to set the configuration.\n",
|
||||
"\n",
|
||||
"This saves having to enter multiple commands and allows for a much more efficient setting of PrimAITE developer mode configurations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Example of setting system log level and enabling the system logging:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config -slevel WARNING -sys"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Another example where the system log and agent action log levels are set and enabled and should be printed to terminal:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite dev-mode config -slevel ERROR -sys -alevel ERROR -agent -t"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Restore PRIMAITE_CONFIG\n",
|
||||
"from primaite.utils.cli.primaite_config_utils import update_primaite_application_config\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"global PRIMAITE_CONFIG\n",
|
||||
"PRIMAITE_CONFIG[\"developer_mode\"] = temp_config[\"developer_mode\"]\n",
|
||||
"update_primaite_application_config(config=PRIMAITE_CONFIG)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -114,7 +114,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(f\"DNS Client state: {client.software_manager.software.get('DNSClient').operating_state.name}\")"
|
||||
"print(f\"DNS Client state: {client.software_manager.software.get('dns-client').operating_state.name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,6 +9,13 @@
|
||||
"© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simulation Layer Implementation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -67,9 +74,9 @@
|
||||
"source": [
|
||||
"network: Network = basic_network()\n",
|
||||
"computer_a: Computer = network.get_node_by_hostname(\"node_a\")\n",
|
||||
"terminal_a: Terminal = computer_a.software_manager.software.get(\"Terminal\")\n",
|
||||
"terminal_a: Terminal = computer_a.software_manager.software.get(\"terminal\")\n",
|
||||
"computer_b: Computer = network.get_node_by_hostname(\"node_b\")\n",
|
||||
"terminal_b: Terminal = computer_b.software_manager.software.get(\"Terminal\")"
|
||||
"terminal_b: Terminal = computer_b.software_manager.software.get(\"terminal\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -121,7 +128,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"RansomwareScript\"])"
|
||||
"term_a_term_b_remote_connection.execute([\"software_manager\", \"application\", \"install\", \"ransomware-script\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -169,6 +176,22 @@
|
||||
"computer_b.file_system.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Information about the latest response when executing a remote command can be seen by calling the `last_response` attribute within `Terminal`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(terminal_a.last_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -207,6 +230,263 @@
|
||||
"source": [
|
||||
"computer_b.user_session_manager.show(include_historic=True, include_session_id=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Game Layer Implementation\n",
|
||||
"\n",
|
||||
"This notebook section will detail the implementation of how the game layer utilises the terminal to support different agent actions.\n",
|
||||
"\n",
|
||||
"The ``Terminal`` is used in a variety of different ways in the game layer. Specifically, the terminal is leveraged to implement the following actions:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"| Game Layer Action | Simulation Layer |\n",
|
||||
"|-----------------------------------|--------------------------|\n",
|
||||
"| ``node-send-local-command`` | Uses the given user credentials, creates a ``LocalTerminalSession`` and executes the given command and returns the ``RequestResponse``.\n",
|
||||
"| ``node-session-remote-login`` | Uses the given user credentials and remote IP to create a ``RemoteTerminalSession``.\n",
|
||||
"| ``node-send-remote-command`` | Uses the given remote IP to locate the correct ``RemoteTerminalSession``, executes the given command and returns the ``RequestsResponse``."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Game Layer Setup\n",
|
||||
"\n",
|
||||
"Similar to other notebooks, the next code cells create a custom proxy agent to demonstrate how these commands can be leveraged by agents in the ``UC2`` network environment.\n",
|
||||
"\n",
|
||||
"If you're unfamiliar with ``UC2`` then please refer to the [UC2-E2E-Demo notebook for further reference](./Data-Manipulation-E2E-Demonstration.ipynb)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"from primaite.config.load import data_manipulation_config_path\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"custom_terminal_agent = \"\"\"\n",
|
||||
" - ref: CustomC2Agent\n",
|
||||
" team: RED\n",
|
||||
" type: proxy-agent\n",
|
||||
" observation_space: null\n",
|
||||
" action_space:\n",
|
||||
" options:\n",
|
||||
" nodes:\n",
|
||||
" - node_name: client_1\n",
|
||||
" max_folders_per_node: 1\n",
|
||||
" max_files_per_folder: 1\n",
|
||||
" max_services_per_node: 2\n",
|
||||
" max_nics_per_node: 8\n",
|
||||
" max_acl_rules: 10\n",
|
||||
" ip_list:\n",
|
||||
" - 192.168.1.21\n",
|
||||
" - 192.168.1.14\n",
|
||||
" wildcard_list:\n",
|
||||
" - 0.0.0.1\n",
|
||||
" action_map:\n",
|
||||
" 0:\n",
|
||||
" action: do-nothing\n",
|
||||
" options: {}\n",
|
||||
" 1:\n",
|
||||
" action: node-send-local-command\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - dog.png\n",
|
||||
" - False\n",
|
||||
" 2:\n",
|
||||
" action: node-session-remote-login\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" 3:\n",
|
||||
" action: node-send-remote-command\n",
|
||||
" options:\n",
|
||||
" node_name: client_1\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - cat.png\n",
|
||||
" - False\n",
|
||||
"\"\"\"\n",
|
||||
"custom_terminal_agent_yaml = yaml.safe_load(custom_terminal_agent)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path()) as f:\n",
|
||||
" cfg = yaml.safe_load(f)\n",
|
||||
" # removing all agents & adding the custom agent.\n",
|
||||
" cfg['agents'] = {}\n",
|
||||
" cfg['agents'] = custom_terminal_agent_yaml\n",
|
||||
"\n",
|
||||
"env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
"\n",
|
||||
"client_1: Computer = env.game.simulation.network.get_node_by_hostname(\"client_1\")\n",
|
||||
"client_2: Computer = env.game.simulation.network.get_node_by_hostname(\"client_2\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-send-local-command`` \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-send-local-command\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: node-send-local-command\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" command:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - dog.png\n",
|
||||
" - False\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(1)\n",
|
||||
"client_1.file_system.show(full=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-session-remote-login`` \n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-session-remote-login\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 2:\n",
|
||||
" action: node-session-remote-login\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" username: admin\n",
|
||||
" password: admin\n",
|
||||
" remote_ip: 192.168.10.22 # client_2's ip address.\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(2)\n",
|
||||
"client_2.session_manager.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Terminal Action | ``node-send-remote-command``\n",
|
||||
"\n",
|
||||
"The yaml snippet below shows all the relevant agent options for this action:\n",
|
||||
"\n",
|
||||
"```yaml\n",
|
||||
"\n",
|
||||
" action_space:\n",
|
||||
" action_list:\n",
|
||||
" ...\n",
|
||||
" - type: node-send-remote-command\n",
|
||||
" ...\n",
|
||||
" options:\n",
|
||||
" nodes: # Node List\n",
|
||||
" - node_name: client_1\n",
|
||||
" ...\n",
|
||||
" ...\n",
|
||||
" action_map:\n",
|
||||
" 1:\n",
|
||||
" action: node-send-remote-command\n",
|
||||
" options:\n",
|
||||
" node_id: 0 # Index 0 at the node list.\n",
|
||||
" remote_ip: 192.168.10.22\n",
|
||||
" commands:\n",
|
||||
" - file_system\n",
|
||||
" - create\n",
|
||||
" - file\n",
|
||||
" - downloads\n",
|
||||
" - cat.png\n",
|
||||
" - False\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.step(3)\n",
|
||||
"client_2.file_system.show(full=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -26,14 +26,26 @@ except ModuleNotFoundError:
|
||||
_LOGGER.debug("Torch not available for importing")
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> Union[None, int]:
|
||||
def set_random_seed(seed: int, generate_seed_value: bool) -> Union[None, int]:
|
||||
"""
|
||||
Set random number generators.
|
||||
|
||||
If seed is None or -1 and generate_seed_value is True randomly generate a
|
||||
seed value.
|
||||
If seed is > -1 and generate_seed_value is True ignore the latter and use
|
||||
the provide seed value.
|
||||
|
||||
:param seed: int
|
||||
:param generate_seed_value: bool
|
||||
:return: None or the int representing the seed used.
|
||||
"""
|
||||
if seed is None or seed == -1:
|
||||
return None
|
||||
if generate_seed_value:
|
||||
rng = np.random.default_rng()
|
||||
# 2**32-1 is highest value for python RNG seed.
|
||||
seed = int(rng.integers(low=0, high=2**32 - 1))
|
||||
else:
|
||||
return None
|
||||
elif seed < -1:
|
||||
raise ValueError("Invalid random number seed")
|
||||
# Seed python RNG
|
||||
@@ -50,6 +62,13 @@ def set_random_seed(seed: int) -> Union[None, int]:
|
||||
return seed
|
||||
|
||||
|
||||
def log_seed_value(seed: int):
|
||||
"""Log the selected seed value to file."""
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "w") as file:
|
||||
file.write(f"Seed value = {seed}")
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
@@ -65,7 +84,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.seed = self.episode_scheduler(0).get("game", {}).get("seed")
|
||||
"""Get RNG seed from config file. NB: Must be before game instantiation."""
|
||||
self.seed = set_random_seed(self.seed)
|
||||
self.generate_seed_value = self.episode_scheduler(0).get("game", {}).get("generate_seed_value")
|
||||
self.seed = set_random_seed(self.seed, self.generate_seed_value)
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
@@ -79,6 +99,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
_LOGGER.info(f"PrimaiteGymEnv RNG seed = {self.seed}")
|
||||
|
||||
log_seed_value(self.seed)
|
||||
|
||||
def action_masks(self) -> np.ndarray:
|
||||
"""
|
||||
Return the action mask for the agent.
|
||||
@@ -146,7 +168,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
set_random_seed(seed, self.generate_seed_value)
|
||||
self.total_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
|
||||
|
||||
if self.io.settings.save_agent_actions:
|
||||
|
||||
@@ -864,7 +864,21 @@ class UserManager(Service, discriminator="user-manager"):
|
||||
"""
|
||||
rm = super()._init_request_manager()
|
||||
|
||||
# todo add doc about requeest schemas
|
||||
# todo add doc about request schemas
|
||||
rm.add_request(
|
||||
"add_user",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(
|
||||
self.add_user(username=request[0], password=request[1], is_admin=request[2])
|
||||
)
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"disable_user",
|
||||
RequestType(
|
||||
func=lambda request, context: RequestResponse.from_bool(self.disable_user(username=request[0]))
|
||||
),
|
||||
)
|
||||
rm.add_request(
|
||||
"change_password",
|
||||
RequestType(
|
||||
@@ -1572,7 +1586,7 @@ class Node(SimComponent, ABC):
|
||||
|
||||
operating_state: Any = None
|
||||
|
||||
users: Any = None # Temporary to appease "extra=forbid"
|
||||
users: List[Dict] = [] # Temporary to appease "extra=forbid"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: Node.ConfigSchema())
|
||||
"""Configuration items within Node"""
|
||||
@@ -1638,6 +1652,8 @@ class Node(SimComponent, ABC):
|
||||
self._install_system_software()
|
||||
self.session_manager.node = self
|
||||
self.session_manager.software_manager = self.software_manager
|
||||
for user in self.config.users:
|
||||
self.user_manager.add_user(**user, bypass_can_perform_action=True)
|
||||
|
||||
@property
|
||||
def user_manager(self) -> Optional[UserManager]:
|
||||
@@ -1769,7 +1785,7 @@ class Node(SimComponent, ABC):
|
||||
"""
|
||||
application_name = request[0]
|
||||
if self.software_manager.software.get(application_name):
|
||||
self.sys_log.warning(f"Can't install {application_name}. It's already installed.")
|
||||
self.sys_log.info(f"Can't install {application_name}. It's already installed.")
|
||||
return RequestResponse(status="success", data={"reason": "already installed"})
|
||||
application_class = Application._registry[application_name]
|
||||
self.software_manager.install(application_class)
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, ClassVar, Dict, Literal, Optional
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.network.hardware.base import (
|
||||
IPWiredNetworkInterface,
|
||||
Link,
|
||||
@@ -313,7 +314,7 @@ class HostNode(Node, discriminator="host-node"):
|
||||
"""
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"HostARP": HostARP,
|
||||
"host-arp": HostARP,
|
||||
"icmp": ICMP,
|
||||
"dns-client": DNSClient,
|
||||
"ntp-client": NTPClient,
|
||||
@@ -339,7 +340,7 @@ class HostNode(Node, discriminator="host-node"):
|
||||
ip_address: IPV4Address
|
||||
services: Any = None # temporarily unset to appease extra="forbid"
|
||||
applications: Any = None # temporarily unset to appease extra="forbid"
|
||||
folders: Any = None # temporarily unset to appease extra="forbid"
|
||||
folders: List[Dict] = {} # temporarily unset to appease extra="forbid"
|
||||
network_interfaces: Any = None # temporarily unset to appease extra="forbid"
|
||||
|
||||
config: ConfigSchema = Field(default_factory=lambda: HostNode.ConfigSchema())
|
||||
@@ -348,6 +349,18 @@ class HostNode(Node, discriminator="host-node"):
|
||||
super().__init__(**kwargs)
|
||||
self.connect_nic(NIC(ip_address=kwargs["config"].ip_address, subnet_mask=kwargs["config"].subnet_mask))
|
||||
|
||||
for folder in self.config.folders:
|
||||
# handle empty foler defined by just a string
|
||||
self.file_system.create_folder(folder["folder_name"])
|
||||
|
||||
for file in folder.get("files", []):
|
||||
self.file_system.create_file(
|
||||
folder_name=folder["folder_name"],
|
||||
file_name=file["file_name"],
|
||||
size=file.get("size", 0),
|
||||
file_type=FileType[file.get("type", "UNKNOWN").upper()],
|
||||
)
|
||||
|
||||
@property
|
||||
def nmap(self) -> Optional[NMAP]:
|
||||
"""
|
||||
|
||||
@@ -49,7 +49,7 @@ class Firewall(Router, discriminator="firewall"):
|
||||
|
||||
Example:
|
||||
>>> from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
>>> from primaite.simulator.network.transmission.transport_layer import Port
|
||||
>>> from primaite.utils.validation.port import Port
|
||||
>>> firewall = Firewall(hostname="Firewall1")
|
||||
>>> firewall.configure_internal_port(ip_address="192.168.1.1", subnet_mask="255.255.255.0")
|
||||
>>> firewall.configure_external_port(ip_address="10.0.0.1", subnet_mask="255.255.255.0")
|
||||
|
||||
@@ -467,6 +467,7 @@ class AccessControlList(SimComponent):
|
||||
"""Check if a packet with the given properties is permitted through the ACL."""
|
||||
permitted = False
|
||||
rule: ACLRule = None
|
||||
|
||||
for _rule in self._acl:
|
||||
if not _rule:
|
||||
continue
|
||||
@@ -1215,9 +1216,9 @@ class Router(NetworkNode, discriminator="router"):
|
||||
config: ConfigSchema = Field(default_factory=lambda: Router.ConfigSchema())
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {
|
||||
"UserSessionManager": UserSessionManager,
|
||||
"UserManager": UserManager,
|
||||
"Terminal": Terminal,
|
||||
"user-session-manager": UserSessionManager,
|
||||
"user-manager": UserManager,
|
||||
"terminal": Terminal,
|
||||
}
|
||||
|
||||
network_interfaces: Dict[str, RouterInterface] = {}
|
||||
@@ -1385,6 +1386,12 @@ class Router(NetworkNode, discriminator="router"):
|
||||
|
||||
return False
|
||||
|
||||
def subject_to_acl(self, frame: Frame) -> bool:
|
||||
"""Check that frame is subject to ACL rules."""
|
||||
if frame.ip.protocol == "udp" and frame.is_arp:
|
||||
return False
|
||||
return True
|
||||
|
||||
def receive_frame(self, frame: Frame, from_network_interface: RouterInterface):
|
||||
"""
|
||||
Processes an incoming frame received on one of the router's interfaces.
|
||||
@@ -1398,8 +1405,12 @@ class Router(NetworkNode, discriminator="router"):
|
||||
if self.operating_state != NodeOperatingState.ON:
|
||||
return
|
||||
|
||||
# Check if it's permitted
|
||||
permitted, rule = self.acl.is_permitted(frame)
|
||||
if self.subject_to_acl(frame=frame):
|
||||
# Check if it's permitted
|
||||
permitted, rule = self.acl.is_permitted(frame)
|
||||
else:
|
||||
permitted = True
|
||||
rule = None
|
||||
|
||||
if not permitted:
|
||||
at_port = self._get_port_of_nic(from_network_interface)
|
||||
|
||||
@@ -163,7 +163,7 @@ class Frame(BaseModel):
|
||||
"""
|
||||
Checks if the Frame is an ARP (Address Resolution Protocol) packet.
|
||||
|
||||
This is determined by checking if the destination port of the TCP header is equal to the ARP port.
|
||||
This is determined by checking if the destination and source port of the UDP header is equal to the ARP port.
|
||||
|
||||
:return: True if the Frame is an ARP packet, otherwise False.
|
||||
"""
|
||||
|
||||
@@ -55,7 +55,7 @@ class ARP(Service, discriminator="arp"):
|
||||
|
||||
:param markdown: If True, format the output as Markdown. Otherwise, use plain text.
|
||||
"""
|
||||
table = PrettyTable(["IP Address", "MAC Address", "Via"])
|
||||
table = PrettyTable(["IP Address", "MAC Address", "Via", "Port"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
@@ -66,6 +66,7 @@ class ARP(Service, discriminator="arp"):
|
||||
str(ip),
|
||||
arp.mac_address,
|
||||
self.software_manager.node.network_interfaces[arp.network_interface_uuid].mac_address,
|
||||
self.software_manager.node.network_interfaces[arp.network_interface_uuid].port_num,
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
|
||||
@@ -142,12 +142,20 @@ class Terminal(Service, discriminator="terminal"):
|
||||
_client_connection_requests: Dict[str, Optional[Union[str, TerminalClientConnection]]] = {}
|
||||
"""Dictionary of connect requests made to remote nodes."""
|
||||
|
||||
_last_response: Optional[RequestResponse] = None
|
||||
"""Last response received from RequestManager, for returning remote RequestResponse."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "terminal"
|
||||
kwargs["port"] = PORT_LOOKUP["SSH"]
|
||||
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def last_response(self) -> Optional[RequestResponse]:
|
||||
"""Public version of _last_response attribute."""
|
||||
return self._last_response
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
Produce a dictionary describing the current state of this object.
|
||||
@@ -186,7 +194,7 @@ class Terminal(Service, discriminator="terminal"):
|
||||
return RequestResponse(status="failure", data={})
|
||||
|
||||
rm.add_request(
|
||||
"node-session-remote-login",
|
||||
"node_session_remote_login",
|
||||
request_type=RequestType(func=_remote_login),
|
||||
)
|
||||
|
||||
@@ -209,28 +217,45 @@ class Terminal(Service, discriminator="terminal"):
|
||||
command: str = request[1]["command"]
|
||||
remote_connection = self._get_connection_from_ip(ip_address=ip_address)
|
||||
if remote_connection:
|
||||
outcome = remote_connection.execute(command)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={},
|
||||
)
|
||||
else:
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={},
|
||||
)
|
||||
remote_connection.execute(command)
|
||||
return self.last_response if not None else RequestResponse(status="failure", data={})
|
||||
return RequestResponse(
|
||||
status="failure",
|
||||
data={"reason": "Failed to execute command."},
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"send_remote_command",
|
||||
request_type=RequestType(func=remote_execute_request),
|
||||
)
|
||||
|
||||
def local_execute_request(request: RequestFormat, context: Dict) -> RequestResponse:
|
||||
"""Executes a command using a local terminal session."""
|
||||
command: str = request[2]["command"]
|
||||
local_connection = self._process_local_login(username=request[0], password=request[1])
|
||||
if local_connection:
|
||||
outcome = local_connection.execute(command)
|
||||
if outcome:
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={"reason": outcome},
|
||||
)
|
||||
return RequestResponse(
|
||||
status="success",
|
||||
data={"reason": "Local Terminal failed to resolve command. Potentially invalid credentials?"},
|
||||
)
|
||||
|
||||
rm.add_request(
|
||||
"send_local_command",
|
||||
request_type=RequestType(func=local_execute_request),
|
||||
)
|
||||
|
||||
return rm
|
||||
|
||||
def execute(self, command: List[Any]) -> Optional[RequestResponse]:
|
||||
"""Execute a passed ssh command via the request manager."""
|
||||
return self.parent.apply_request(command)
|
||||
self._last_response = self.parent.apply_request(command)
|
||||
return self._last_response
|
||||
|
||||
def _get_connection_from_ip(self, ip_address: IPv4Address) -> Optional[RemoteTerminalConnection]:
|
||||
"""Find Remote Terminal Connection from a given IP."""
|
||||
@@ -409,6 +434,8 @@ class Terminal(Service, discriminator="terminal"):
|
||||
"""
|
||||
source_ip = kwargs["frame"].ip.src_ip_address
|
||||
self.sys_log.info(f"{self.name}: Received payload: {payload}. Source: {source_ip}")
|
||||
self._last_response = None # Clear last response
|
||||
|
||||
if isinstance(payload, SSHPacket):
|
||||
if payload.transport_message == SSHTransportMessage.SSH_MSG_USERAUTH_REQUEST:
|
||||
# validate & add connection
|
||||
@@ -457,6 +484,9 @@ class Terminal(Service, discriminator="terminal"):
|
||||
session_id=session_id,
|
||||
source_ip=source_ip,
|
||||
)
|
||||
self._last_response: RequestResponse = RequestResponse(
|
||||
status="success", data={"reason": "Login Successful"}
|
||||
)
|
||||
|
||||
elif payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_REQUEST:
|
||||
# Requesting a command to be executed
|
||||
@@ -468,12 +498,32 @@ class Terminal(Service, discriminator="terminal"):
|
||||
payload.connection_uuid
|
||||
)
|
||||
remote_session.last_active_step = self.software_manager.node.user_session_manager.current_timestep
|
||||
self.execute(command)
|
||||
self._last_response: RequestResponse = self.execute(command)
|
||||
|
||||
if self._last_response.status == "success":
|
||||
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
|
||||
else:
|
||||
transport_message = SSHTransportMessage.SSH_MSG_SERVICE_FAILED
|
||||
|
||||
payload: SSHPacket = SSHPacket(
|
||||
payload=self._last_response,
|
||||
transport_message=transport_message,
|
||||
connection_message=SSHConnectionMessage.SSH_MSG_CHANNEL_DATA,
|
||||
)
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_port=self.port, session_id=session_id
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connection UUID:{payload.connection_uuid} is not valid. Rejecting Command."
|
||||
)
|
||||
elif (
|
||||
payload.transport_message == SSHTransportMessage.SSH_MSG_SERVICE_SUCCESS
|
||||
or SSHTransportMessage.SSH_MSG_SERVICE_FAILED
|
||||
):
|
||||
# Likely receiving command ack from remote.
|
||||
self._last_response = payload.payload
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "disconnect":
|
||||
|
||||
@@ -117,37 +117,44 @@ class WebServer(Service, discriminator="web-server"):
|
||||
:type: payload: HttpRequestPacket
|
||||
"""
|
||||
response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND, payload=payload)
|
||||
try:
|
||||
parsed_url = urlparse(payload.request_url)
|
||||
path = parsed_url.path.strip("/")
|
||||
|
||||
if len(path) < 1:
|
||||
parsed_url = urlparse(payload.request_url)
|
||||
path = parsed_url.path.strip("/") if parsed_url and parsed_url.path else ""
|
||||
|
||||
if len(path) < 1:
|
||||
# query succeeded
|
||||
response.status_code = HttpStatusCode.OK
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
# get all users
|
||||
if not self._establish_db_connection():
|
||||
# unable to create a db connection
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
return response
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
# get all users
|
||||
if not self.db_connection:
|
||||
self._establish_db_connection()
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
else:
|
||||
self.set_health_state(SoftwareHealthState.COMPROMISED)
|
||||
|
||||
return response
|
||||
except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 )
|
||||
# something went wrong on the server
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
def _establish_db_connection(self) -> None:
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a connection to db."""
|
||||
# if active db connection, return true
|
||||
if self.db_connection:
|
||||
return True
|
||||
|
||||
# otherwise, try to create db connection
|
||||
db_client = self.software_manager.software.get("database-client")
|
||||
|
||||
if db_client is None:
|
||||
return False # database client not installed
|
||||
|
||||
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
return self.db_connection is not None
|
||||
|
||||
def send(
|
||||
self,
|
||||
|
||||
@@ -25,7 +25,19 @@ game:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 100
|
||||
medium: 25
|
||||
low: 5
|
||||
file_access:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 2
|
||||
app_executions:
|
||||
high: 5
|
||||
medium: 3
|
||||
low: 2
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
@@ -64,10 +76,16 @@ agents:
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
applications:
|
||||
- application_name: WebBrowser
|
||||
folders:
|
||||
- folder_name: root
|
||||
files:
|
||||
- file_name: "test.txt"
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
@@ -182,6 +200,10 @@ simulation:
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
- type: ntp-server
|
||||
folders:
|
||||
- folder_name: root
|
||||
files:
|
||||
- file_name: test.txt
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
|
||||
226
tests/assets/configs/nodes_with_initial_files.yaml
Normal file
226
tests/assets/configs/nodes_with_initial_files.yaml
Normal file
@@ -0,0 +1,226 @@
|
||||
# Basic Switched network
|
||||
#
|
||||
# -------------- -------------- --------------
|
||||
# | client_1 |------| switch_1 |------| client_2 |
|
||||
# -------------- -------------- --------------
|
||||
#
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
agent_log_level: INFO
|
||||
save_agent_logs: true
|
||||
write_agent_log_to_terminal: True
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: periodic-agent
|
||||
action_space:
|
||||
action_map:
|
||||
0:
|
||||
action: do-nothing
|
||||
options: {}
|
||||
1:
|
||||
action: node-application-execute
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
agent_settings:
|
||||
possible_start_nodes: [client_2,]
|
||||
target_application: web-browser
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: proxy-agent
|
||||
|
||||
observation_space:
|
||||
type: custom
|
||||
options:
|
||||
components:
|
||||
- type: nodes
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
include_nmne: false
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: links
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- switch_1:eth-1<->client_1:eth-1
|
||||
- switch_1:eth-2<->client_2:eth-1
|
||||
- type: none
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_map:
|
||||
0:
|
||||
action: do-nothing
|
||||
options: {}
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: database-file-integrity
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
- type: web-server-404-penalty
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- hostname: client_1
|
||||
type: computer
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: ransomware-script
|
||||
- type: web-browser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
- type: database-client
|
||||
options:
|
||||
db_server_ip: 192.168.1.10
|
||||
server_password: arcd
|
||||
- type: data-manipulation-bot
|
||||
options:
|
||||
port_scan_p_of_success: 0.8
|
||||
data_manipulation_p_of_success: 0.8
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.21
|
||||
server_password: arcd
|
||||
- type: dos-bot
|
||||
options:
|
||||
target_ip_address: 192.168.10.21
|
||||
payload: SPOOF DATA
|
||||
port_scan_p_of_success: 0.8
|
||||
services:
|
||||
- type: dns-client
|
||||
options:
|
||||
dns_server: 192.168.1.10
|
||||
- type: dns-server
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.10
|
||||
- type: database-service
|
||||
options:
|
||||
backup_server_ip: 192.168.1.10
|
||||
- type: web-server
|
||||
- type: ftp-server
|
||||
options:
|
||||
server_password: arcd
|
||||
- type: ntp-client
|
||||
options:
|
||||
ntp_server_ip: 192.168.1.10
|
||||
- type: ntp-server
|
||||
- hostname: client_2
|
||||
type: computer
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
folders:
|
||||
- folder_name: empty_folder
|
||||
- folder_name: downloads
|
||||
files:
|
||||
- file_name: "test.txt"
|
||||
- file_name: "another_file.pwtwoti"
|
||||
- folder_name: root
|
||||
files:
|
||||
- file_name: passwords
|
||||
size: 663
|
||||
type: TXT
|
||||
# pre installed services and applications
|
||||
- hostname: client_3
|
||||
type: computer
|
||||
ip_address: 192.168.10.23
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
start_up_duration: 0
|
||||
shut_down_duration: 0
|
||||
operating_state: "OFF"
|
||||
# pre installed services and applications
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
bandwidth: 200
|
||||
@@ -8,7 +8,7 @@ from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
BASIC_SWITCHED_NETWORK_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
@@ -24,3 +24,42 @@ def test_thresholds():
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert game.options.thresholds is not None
|
||||
|
||||
|
||||
def test_nmne_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["nmne"] is not None
|
||||
|
||||
# get NIC observation
|
||||
nic_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].nics[0]
|
||||
assert nic_obs.low_nmne_threshold == 5
|
||||
assert nic_obs.med_nmne_threshold == 25
|
||||
assert nic_obs.high_nmne_threshold == 100
|
||||
|
||||
|
||||
def test_file_access_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["file_access"] is not None
|
||||
|
||||
# get file observation
|
||||
file_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].folders[0].files[0]
|
||||
assert file_obs.low_file_access_threshold == 2
|
||||
assert file_obs.med_file_access_threshold == 5
|
||||
assert file_obs.high_file_access_threshold == 10
|
||||
|
||||
|
||||
def test_app_executions_threshold():
|
||||
"""Test that the NMNE thresholds are properly loaded in by observation."""
|
||||
game = load_config(BASIC_SWITCHED_NETWORK_CONFIG)
|
||||
|
||||
assert game.options.thresholds["app_executions"] is not None
|
||||
|
||||
# get application observation
|
||||
app_obs = game.agents["defender"].observation_manager.obs.components["NODES"].hosts[0].applications[0]
|
||||
assert app_obs.low_app_execution_threshold == 2
|
||||
assert app_obs.med_app_execution_threshold == 3
|
||||
assert app_obs.high_app_execution_threshold == 5
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/nodes_with_initial_files.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_node_file_system_from_config():
|
||||
"""Test that the appropriate files are instantiated in nodes when loaded from config."""
|
||||
game = load_config(BASIC_CONFIG)
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
assert client_1.software_manager.software.get("database-service") # database service should be installed
|
||||
assert client_1.file_system.get_file(folder_name="database", file_name="database.db") # database files should exist
|
||||
|
||||
assert client_1.software_manager.software.get("web-server") # web server should be installed
|
||||
assert client_1.file_system.get_file(folder_name="primaite", file_name="index.html") # web files should exist
|
||||
|
||||
client_2 = game.simulation.network.get_node_by_hostname("client_2")
|
||||
|
||||
# database service should not be installed
|
||||
assert client_2.software_manager.software.get("database-service") is None
|
||||
# database files should not exist
|
||||
assert client_2.file_system.get_file(folder_name="database", file_name="database.db") is None
|
||||
|
||||
# web server should not be installed
|
||||
assert client_2.software_manager.software.get("web-server") is None
|
||||
# web files should not exist
|
||||
assert client_2.file_system.get_file(folder_name="primaite", file_name="index.html") is None
|
||||
|
||||
empty_folder = client_2.file_system.get_folder(folder_name="empty_folder")
|
||||
assert empty_folder
|
||||
assert len(empty_folder.files) == 0 # should have no files
|
||||
|
||||
password_file = client_2.file_system.get_file(folder_name="root", file_name="passwords.txt")
|
||||
assert password_file # should exist
|
||||
assert password_file.file_type is FileType.TXT
|
||||
assert password_file.size == 663
|
||||
|
||||
downloads_folder = client_2.file_system.get_folder(folder_name="downloads")
|
||||
assert downloads_folder # downloads folder should exist
|
||||
|
||||
test_txt = downloads_folder.get_file(file_name="test.txt")
|
||||
assert test_txt # test.txt should exist
|
||||
assert test_txt.file_type is FileType.TXT
|
||||
|
||||
unknown_file_type = downloads_folder.get_file(file_name="another_file.pwtwoti")
|
||||
assert unknown_file_type # unknown_file_type should exist
|
||||
assert unknown_file_type.file_type is FileType.UNKNOWN
|
||||
@@ -49,7 +49,7 @@ class GigaSwitch(NetworkNode, discriminator="gigaswitch"):
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.hostname} Switch Ports"
|
||||
table.title = f"{self.config.hostname} Switch Ports"
|
||||
for port_num, port in self.network_interface.items():
|
||||
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
|
||||
print(table)
|
||||
|
||||
@@ -106,7 +106,6 @@ def test_remote_login_change_password(game_and_agent_fixture: Tuple[PrimaiteGame
|
||||
"username": "user123",
|
||||
"current_password": "password",
|
||||
"new_password": "different_password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -146,7 +145,6 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
|
||||
"username": "user123",
|
||||
"current_password": "password",
|
||||
"new_password": "different_password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
@@ -166,3 +164,55 @@ def test_change_password_logs_out_user(game_and_agent_fixture: Tuple[PrimaiteGam
|
||||
|
||||
assert server_1.file_system.get_folder("folder123") is None
|
||||
assert server_1.file_system.get_file("folder123", "doggo.pdf") is None
|
||||
|
||||
|
||||
def test_local_terminal(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
# create a new user account on server_1 that will be logged into remotely
|
||||
client_1_usm: UserManager = client_1.software_manager.software["user-manager"]
|
||||
client_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"node-send-local-command",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"command": ["file_system", "create", "file", "folder123", "doggo.pdf", False],
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert client_1.file_system.get_folder("folder123")
|
||||
assert client_1.file_system.get_file("folder123", "doggo.pdf")
|
||||
|
||||
# Change password
|
||||
action = (
|
||||
"node-account-change-password",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "user123",
|
||||
"current_password": "password",
|
||||
"new_password": "different_password",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
action = (
|
||||
"node-send-local-command",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"command": ["file_system", "create", "file", "folder123", "cat.pdf", False],
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert client_1.file_system.get_file("folder123", "cat.pdf") is None
|
||||
client_1.session_manager.show()
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
|
||||
from primaite.utils.validation.port import Port, PORT_LOOKUP
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def game_and_agent_fixture(game_and_agent):
|
||||
"""Create a game with a simple agent that can be controlled by the tests."""
|
||||
game, agent = game_and_agent
|
||||
|
||||
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.start_up_duration = 3
|
||||
|
||||
return (game, agent)
|
||||
|
||||
|
||||
def test_user_account_add_user_action(game_and_agent_fixture):
|
||||
"""Tests the add user account action."""
|
||||
game, agent = game_and_agent_fixture
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
assert len(client_1.user_manager.users) == 1 # admin is created by default
|
||||
assert len(client_1.user_manager.admins) == 1
|
||||
|
||||
# add admin account
|
||||
action = (
|
||||
"node-account-add-user",
|
||||
{"node_name": "client_1", "username": "admin_2", "password": "e-tronic-boogaloo", "is_admin": True},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert len(client_1.user_manager.users) == 2 # new user added
|
||||
assert len(client_1.user_manager.admins) == 2
|
||||
|
||||
# add non admin account
|
||||
action = (
|
||||
"node-account-add-user",
|
||||
{"node_name": "client_1", "username": "leeroy.jenkins", "password": "no_plan_needed", "is_admin": False},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert len(client_1.user_manager.users) == 3 # new user added
|
||||
assert len(client_1.user_manager.admins) == 2
|
||||
|
||||
|
||||
def test_user_account_disable_user_action(game_and_agent_fixture):
|
||||
"""Tests the disable user account action."""
|
||||
game, agent = game_and_agent_fixture
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
client_1.user_manager.add_user(username="test", password="password", is_admin=True)
|
||||
assert len(client_1.user_manager.users) == 2 # new user added
|
||||
assert len(client_1.user_manager.admins) == 2
|
||||
|
||||
test_user = client_1.user_manager.users.get("test")
|
||||
assert test_user
|
||||
assert test_user.disabled is not True
|
||||
|
||||
# disable test account
|
||||
action = (
|
||||
"node-account-disable-user",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "test",
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert test_user.disabled
|
||||
|
||||
|
||||
def test_user_account_change_password_action(game_and_agent_fixture):
|
||||
"""Tests the change password user account action."""
|
||||
game, agent = game_and_agent_fixture
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
client_1.user_manager.add_user(username="test", password="password", is_admin=True)
|
||||
|
||||
test_user = client_1.user_manager.users.get("test")
|
||||
assert test_user.password == "password"
|
||||
|
||||
# change account password
|
||||
action = (
|
||||
"node-account-change-password",
|
||||
{"node_name": "client_1", "username": "test", "current_password": "password", "new_password": "2Hard_2_Hack"},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
assert test_user.password == "2Hard_2_Hack"
|
||||
|
||||
|
||||
def test_user_account_create_terminal_action(game_and_agent_fixture):
|
||||
"""Tests that agents can use the terminal to create new users."""
|
||||
game, agent = game_and_agent_fixture
|
||||
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4)
|
||||
|
||||
server_1 = game.simulation.network.get_node_by_hostname("server_1")
|
||||
server_1_usm = server_1.software_manager.software["user-manager"]
|
||||
server_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"node-session-remote-login",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.history[-1].response.status == "success"
|
||||
|
||||
# Create a new user account via terminal.
|
||||
action = (
|
||||
"node-send-remote-command",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
"command": ["service", "user-manager", "add_user", "new_user", "new_pass", True],
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
new_user = server_1.user_manager.users.get("new_user")
|
||||
assert new_user
|
||||
assert new_user.password == "new_pass"
|
||||
assert new_user.disabled is not True
|
||||
|
||||
|
||||
def test_user_account_disable_terminal_action(game_and_agent_fixture):
|
||||
"""Tests that agents can use the terminal to disable users."""
|
||||
game, agent = game_and_agent_fixture
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=4)
|
||||
|
||||
server_1 = game.simulation.network.get_node_by_hostname("server_1")
|
||||
server_1_usm = server_1.software_manager.software["user-manager"]
|
||||
server_1_usm.add_user("user123", "password", is_admin=True)
|
||||
|
||||
action = (
|
||||
"node-session-remote-login",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"username": "user123",
|
||||
"password": "password",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
assert agent.history[-1].response.status == "success"
|
||||
|
||||
# Disable a user via terminal
|
||||
action = (
|
||||
"node-send-remote-command",
|
||||
{
|
||||
"node_name": "client_1",
|
||||
"remote_ip": str(server_1.network_interface[1].ip_address),
|
||||
"command": ["service", "user-manager", "disable_user", "user123"],
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
game.step()
|
||||
|
||||
new_user = server_1.user_manager.users.get("user123")
|
||||
assert new_user
|
||||
assert new_user.disabled is True
|
||||
@@ -44,6 +44,38 @@ def test_file_observation(simulation):
|
||||
assert observation_state.get("health_status") == 3 # corrupted
|
||||
|
||||
|
||||
def test_config_file_access_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert file_obs.high_file_access_threshold == 9
|
||||
assert file_obs.med_file_access_threshold == 6
|
||||
assert file_obs.low_file_access_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
FileObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
file_system_requires_scan=True,
|
||||
thresholds={"file_access": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
|
||||
def test_folder_observation(simulation):
|
||||
"""Test the folder observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
@@ -77,6 +77,14 @@ def test_nic(simulation):
|
||||
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.config.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
# The Simulation object created by the fixture also creates the
|
||||
# NICObservation class with the NICObservation.capture_nmnme class variable
|
||||
# set to False. Under normal (non-test) circumstances this class variable
|
||||
# is set from a config file such as data_manipulation.yaml. So although
|
||||
# capture_nmne is set to True in the NetworkInterface class it's still False
|
||||
# in the NICObservation class so we set it now.
|
||||
nic_obs.capture_nmne = True
|
||||
|
||||
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
|
||||
nmne_config = {
|
||||
"capture_nmne": True, # Enable the capture of MNEs
|
||||
@@ -115,14 +123,11 @@ def test_nic_categories(simulation):
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Feature not implemented yet")
|
||||
def test_config_nic_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
where=["network", "nodes", pc.config.hostname, "NICs", 1],
|
||||
thresholds={"nmne": {"low": 3, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
@@ -133,20 +138,16 @@ def test_config_nic_categories(simulation):
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
where=["network", "nodes", pc.config.hostname, "NICs", 1],
|
||||
thresholds={"nmne": {"low": 9, "medium": 6, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
where=["network", "nodes", pc.config.hostname, "NICs", 1],
|
||||
thresholds={"nmne": {"low": 3, "medium": 9, "high": 9}},
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ def test_host_observation(simulation):
|
||||
folders=[],
|
||||
network_interfaces=[],
|
||||
file_system_requires_scan=True,
|
||||
services_requires_scan=True,
|
||||
applications_requires_scan=True,
|
||||
include_users=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
DATA_MANIPULATION_CONFIG = TEST_ASSETS_ROOT / "configs" / "data_manipulation.yaml"
|
||||
|
||||
|
||||
def test_obs_data_in_log_file():
|
||||
"""Create a log file of AgentHistoryItems and check observation data is
|
||||
included. Assumes that data_manipulation.yaml has an agent labelled
|
||||
'defender' with a non-null observation space.
|
||||
The log file will be in:
|
||||
primaite/VERSION/sessions/YYYY-MM-DD/HH-MM-SS/agent_actions
|
||||
"""
|
||||
env = PrimaiteGymEnv(DATA_MANIPULATION_CONFIG)
|
||||
env.reset()
|
||||
for _ in range(10):
|
||||
env.step(0)
|
||||
env.reset()
|
||||
io = PrimaiteIO()
|
||||
path = io.generate_agent_actions_save_path(episode=1)
|
||||
with open(path, "r") as f:
|
||||
j = json.load(f)
|
||||
|
||||
assert type(j["0"]["defender"]["observation"]) == dict
|
||||
@@ -29,7 +29,9 @@ def test_service_observation(simulation):
|
||||
ntp_server = pc.software_manager.software.get("ntp-server")
|
||||
assert ntp_server
|
||||
|
||||
service_obs = ServiceObservation(where=["network", "nodes", pc.config.hostname, "services", "ntp-server"])
|
||||
service_obs = ServiceObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "services", "ntp-server"], services_requires_scan=True
|
||||
)
|
||||
|
||||
assert service_obs.space["operating_status"] == spaces.Discrete(7)
|
||||
assert service_obs.space["health_status"] == spaces.Discrete(5)
|
||||
@@ -54,7 +56,9 @@ def test_application_observation(simulation):
|
||||
web_browser: WebBrowser = pc.software_manager.software.get("web-browser")
|
||||
assert web_browser
|
||||
|
||||
app_obs = ApplicationObservation(where=["network", "nodes", pc.config.hostname, "applications", "web-browser"])
|
||||
app_obs = ApplicationObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "applications", "web-browser"], applications_requires_scan=True
|
||||
)
|
||||
|
||||
web_browser.close()
|
||||
observation_state = app_obs.observe(simulation.describe_state())
|
||||
@@ -69,3 +73,33 @@ def test_application_observation(simulation):
|
||||
assert observation_state.get("health_status") == 1
|
||||
assert observation_state.get("operating_status") == 1 # running
|
||||
assert observation_state.get("num_executions") == 1
|
||||
|
||||
|
||||
def test_application_executions_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
app_obs = ApplicationObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 3, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
assert app_obs.high_app_execution_threshold == 9
|
||||
assert app_obs.med_app_execution_threshold == 6
|
||||
assert app_obs.low_app_execution_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 9, "medium": 6, "high": 9}},
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
ApplicationObservation(
|
||||
where=["network", "nodes", pc.config.hostname, "applications", "WebBrowser"],
|
||||
applications_requires_scan=False,
|
||||
thresholds={"app_executions": {"low": 3, "medium": 9, "high": 9}},
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.interface import AgentHistoryItem
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -33,6 +34,11 @@ def test_rng_seed_set(create_env):
|
||||
|
||||
assert a == b
|
||||
|
||||
# Check that seed log file was created.
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "r") as file:
|
||||
assert file
|
||||
|
||||
|
||||
def test_rng_seed_unset(create_env):
|
||||
"""Test with no RNG seed."""
|
||||
@@ -48,3 +54,19 @@ def test_rng_seed_unset(create_env):
|
||||
b = [item.timestep for item in env.game.agents["client_2_green_user"].history if item.action != "do-nothing"]
|
||||
|
||||
assert a != b
|
||||
|
||||
|
||||
def test_for_generated_seed():
|
||||
"""
|
||||
Show that setting generate_seed_value to true producess a valid seed.
|
||||
"""
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
cfg["game"]["generate_seed_value"] = True
|
||||
PrimaiteGymEnv(env_config=cfg)
|
||||
path = SIM_OUTPUT.path / "seed.log"
|
||||
with open(path, "r") as file:
|
||||
data = file.read()
|
||||
|
||||
assert data.split(" ")[3] != None
|
||||
|
||||
@@ -22,6 +22,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import Router
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
@@ -107,7 +108,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
"""
|
||||
Test that the RouterACLAddRuleAction can form a request and that it is accepted by the simulation.
|
||||
|
||||
The acl starts off with 4 rules, and we add a rule, and check that the acl now has 5 rules.
|
||||
The ACL starts off with 4 rules, and we add a rule, and check that the ACL now has 5 rules.
|
||||
"""
|
||||
game, agent = game_and_agent
|
||||
|
||||
@@ -164,11 +165,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox
|
||||
},
|
||||
)
|
||||
agent.store_action(action)
|
||||
print(agent.most_recent_action)
|
||||
game.step()
|
||||
print(agent.most_recent_action)
|
||||
|
||||
# 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2
|
||||
print(router.acl.show())
|
||||
assert router.acl.num_rules == 6
|
||||
assert server_1.ping("10.0.2.3") # Can ping server_2
|
||||
|
||||
@@ -180,7 +179,8 @@ def test_router_acl_removerule_integration(game_and_agent: Tuple[PrimaiteGame, P
|
||||
# 1: Check that http traffic is going across the network nicely.
|
||||
client_1 = game.simulation.network.get_node_by_hostname("client_1")
|
||||
server_1 = game.simulation.network.get_node_by_hostname("server_1")
|
||||
router = game.simulation.network.get_node_by_hostname("router")
|
||||
router: Router = game.simulation.network.get_node_by_hostname("router")
|
||||
assert router.acl.num_rules == 4
|
||||
|
||||
browser: WebBrowser = client_1.software_manager.software.get("web-browser")
|
||||
browser.run()
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from itertools import product
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
@@ -277,3 +283,19 @@ def test_capture_nmne_observations(uc2_network: Network):
|
||||
assert web_nic_obs["outbound"] == expected_nmne
|
||||
assert db_nic_obs["inbound"] == expected_nmne
|
||||
uc2_network.apply_timestep(timestep=0)
|
||||
|
||||
|
||||
def test_nmne_parameter_settings():
|
||||
"""
|
||||
Check that the four permutations of the values of capture_nmne and
|
||||
include_nmne work as expected.
|
||||
"""
|
||||
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
DEFENDER = 3
|
||||
for capture, include in product([True, False], [True, False]):
|
||||
cfg["simulation"]["network"]["nmne_config"]["capture_nmne"] = capture
|
||||
cfg["agents"][DEFENDER]["observation_space"]["options"]["components"][0]["options"]["include_nmne"] = include
|
||||
PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
from primaite.simulator.network.hardware.nodes.network.router import RouterARP
|
||||
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router, RouterARP
|
||||
from primaite.simulator.system.services.arp.arp import ARP
|
||||
from primaite.utils.validation.port import PORT_LOOKUP
|
||||
from tests.integration_tests.network.test_routing import multi_hop_network
|
||||
|
||||
|
||||
@@ -48,3 +49,19 @@ def test_arp_fails_for_network_address_between_routers(multi_hop_network):
|
||||
actual_result = router_1_arp.get_arp_cache_mac_address(router_1.network_interface[1].ip_network.network_address)
|
||||
|
||||
assert actual_result == expected_result
|
||||
|
||||
|
||||
def test_arp_not_affected_by_acl(multi_hop_network):
|
||||
pc_a = multi_hop_network.get_node_by_hostname("pc_a")
|
||||
router_1: Router = multi_hop_network.get_node_by_hostname("router_1")
|
||||
|
||||
# Add explicit rule to block ARP traffic. This shouldn't actually stop ARP traffic
|
||||
# as it operates a different layer within the network.
|
||||
router_1.acl.add_rule(action=ACLAction.DENY, src_port=PORT_LOOKUP["ARP"], dst_port=PORT_LOOKUP["ARP"], position=23)
|
||||
|
||||
pc_a_arp: ARP = pc_a.software_manager.arp
|
||||
|
||||
expected_result = router_1.network_interface[2].mac_address
|
||||
actual_result = pc_a_arp.get_arp_cache_mac_address(router_1.network_interface[2].ip_address)
|
||||
|
||||
assert actual_result == expected_result
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.observations import ApplicationObservation, ObservationManager, ServiceObservation
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
|
||||
@@ -136,3 +137,227 @@ class TestFileSystemRequiresScan:
|
||||
[], files=[], num_files=0, include_num_access=False, file_system_requires_scan=True
|
||||
)
|
||||
assert obs_requiring_scan.observe(folder_state)["health_status"] == 1
|
||||
|
||||
|
||||
class TestServicesRequiresScan:
|
||||
@pytest.mark.parametrize(
|
||||
("yaml_option_string", "expected_val"),
|
||||
(
|
||||
("services_requires_scan: true", True),
|
||||
("services_requires_scan: false", False),
|
||||
(" ", True),
|
||||
),
|
||||
)
|
||||
def test_obs_config(self, yaml_option_string, expected_val):
|
||||
"""Check that the default behaviour is to set service_requires_scan to True."""
|
||||
obs_cfg_yaml = f"""
|
||||
type: custom
|
||||
options:
|
||||
components:
|
||||
- type: nodes
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: web-server
|
||||
- service_name: dns-client
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
services:
|
||||
- service_name: ftp-server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 3
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
{yaml_option_string}
|
||||
include_nmne: true
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: links
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: none
|
||||
label: ICS
|
||||
options: {{}}
|
||||
|
||||
"""
|
||||
|
||||
cfg = yaml.safe_load(obs_cfg_yaml)
|
||||
manager = ObservationManager.from_config(cfg)
|
||||
|
||||
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
|
||||
for i, host in enumerate(hosts):
|
||||
services: List[ServiceObservation] = host.services
|
||||
for j, service in enumerate(services):
|
||||
val = service.services_requires_scan
|
||||
print(f"host {i} service {j} {val}")
|
||||
assert val == expected_val # Make sure services require scan by default
|
||||
|
||||
def test_services_requires_scan(self):
|
||||
state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1}
|
||||
|
||||
obs_requiring_scan = ServiceObservation([], services_requires_scan=True)
|
||||
assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value
|
||||
|
||||
obs_not_requiring_scan = ServiceObservation([], services_requires_scan=False)
|
||||
assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value
|
||||
|
||||
|
||||
class TestApplicationsRequiresScan:
|
||||
@pytest.mark.parametrize(
|
||||
("yaml_option_string", "expected_val"),
|
||||
(
|
||||
("applications_requires_scan: true", True),
|
||||
("applications_requires_scan: false", False),
|
||||
(" ", True),
|
||||
),
|
||||
)
|
||||
def test_obs_config(self, yaml_option_string, expected_val):
|
||||
"""Check that the default behaviour is to set applications_requires_scan to True."""
|
||||
obs_cfg_yaml = f"""
|
||||
type: custom
|
||||
options:
|
||||
components:
|
||||
- type: nodes
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
applications:
|
||||
- application_name: web-browser
|
||||
- hostname: client_2
|
||||
applications:
|
||||
- application_name: web-browser
|
||||
- application_name: database-client
|
||||
num_services: 0
|
||||
num_applications: 3
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
{yaml_option_string}
|
||||
include_nmne: true
|
||||
monitored_traffic:
|
||||
icmp:
|
||||
- NONE
|
||||
tcp:
|
||||
- DNS
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: links
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: none
|
||||
label: ICS
|
||||
options: {{}}
|
||||
|
||||
"""
|
||||
|
||||
cfg = yaml.safe_load(obs_cfg_yaml)
|
||||
manager = ObservationManager.from_config(cfg)
|
||||
|
||||
hosts: List[HostObservation] = manager.obs.components["NODES"].hosts
|
||||
for i, host in enumerate(hosts):
|
||||
services: List[ServiceObservation] = host.services
|
||||
for j, service in enumerate(services):
|
||||
val = service.services_requires_scan
|
||||
print(f"host {i} service {j} {val}")
|
||||
assert val == expected_val # Make sure applications require scan by default
|
||||
|
||||
def test_applications_requires_scan(self):
|
||||
state = {"health_state_actual": 3, "health_state_visible": 1, "operating_state": 1, "num_executions": 1}
|
||||
|
||||
obs_requiring_scan = ApplicationObservation([], applications_requires_scan=True)
|
||||
assert obs_requiring_scan.observe(state)["health_status"] == 1 # should be visible value
|
||||
|
||||
obs_not_requiring_scan = ApplicationObservation([], applications_requires_scan=False)
|
||||
assert obs_not_requiring_scan.observe(state)["health_status"] == 3 # should be actual value
|
||||
|
||||
@@ -73,7 +73,7 @@ def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
|
||||
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
|
||||
|
||||
|
||||
def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
|
||||
def test_ftp_tries_to_send_file__that_does_not_exist(ftp_client):
|
||||
"""Method send_file should return false if no file to send."""
|
||||
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
@@ -442,3 +443,59 @@ def test_terminal_connection_timeout(basic_network):
|
||||
assert len(computer_b.user_session_manager.remote_sessions) == 0
|
||||
|
||||
assert not remote_connection.is_active
|
||||
|
||||
|
||||
def test_terminal_last_response_updates(basic_network):
|
||||
"""Test that the _last_response within Terminal correctly updates."""
|
||||
network: Network = basic_network
|
||||
computer_a: Computer = network.get_node_by_hostname("node_a")
|
||||
terminal_a: Terminal = computer_a.software_manager.software.get("terminal")
|
||||
computer_b: Computer = network.get_node_by_hostname("node_b")
|
||||
|
||||
assert terminal_a.last_response is None
|
||||
|
||||
remote_connection = terminal_a.login(username="admin", password="admin", ip_address="192.168.0.11")
|
||||
|
||||
# Last response should be a successful logon
|
||||
assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "Login Successful"})
|
||||
|
||||
remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"])
|
||||
|
||||
# Last response should now update following successful install
|
||||
assert terminal_a.last_response == RequestResponse(status="success", data={})
|
||||
|
||||
remote_connection.execute(command=["software_manager", "application", "install", "ransomware-script"])
|
||||
|
||||
# Last response should now update to success, but with supplied reason.
|
||||
assert terminal_a.last_response == RequestResponse(status="success", data={"reason": "already installed"})
|
||||
|
||||
remote_connection.execute(command=["file_system", "create", "file", "folder123", "doggo.pdf", False])
|
||||
|
||||
# Check file was created.
|
||||
assert computer_b.file_system.access_file(folder_name="folder123", file_name="doggo.pdf")
|
||||
|
||||
# Last response should be confirmation of file creation.
|
||||
assert terminal_a.last_response == RequestResponse(
|
||||
status="success",
|
||||
data={"file_name": "doggo.pdf", "folder_name": "folder123", "file_type": "PDF", "file_size": 102400},
|
||||
)
|
||||
|
||||
remote_connection.execute(
|
||||
command=[
|
||||
"service",
|
||||
"ftp-client",
|
||||
"send",
|
||||
{
|
||||
"dest_ip_address": "192.168.0.2",
|
||||
"src_folder": "folder123",
|
||||
"src_file_name": "cat.pdf",
|
||||
"dest_folder": "root",
|
||||
"dest_file_name": "cat.pdf",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert terminal_a.last_response == RequestResponse(
|
||||
status="failure",
|
||||
data={"reason": "Unable to locate given file on local file system. Perhaps given options are invalid?"},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user